use std::env;

use anyhow::{Context, Result};
use axum::http::{HeaderMap, HeaderValue};
use cog_core::http::WebhookEvent;
use reqwest::Client;
use url::Url;

use crate::prediction::{Prediction, ResponseHelpers};
10
11pub struct WebhookSender {
12 client: Client,
13}
14
15impl WebhookSender {
16 pub fn new() -> Result<Self> {
17 let mut headers = HeaderMap::new();
18 let client = Client::builder();
19
20 if let Ok(token) = env::var("WEBHOOK_AUTH_TOKEN") {
21 let mut authorization = HeaderValue::from_str(&format!("Bearer {token}"))?;
22 authorization.set_sensitive(true);
23 headers.insert("Authorization", authorization);
24 }
25
26 Ok(Self {
27 client: client
28 .user_agent(format!("cog-worker/{}", env!("CARGO_PKG_VERSION")))
29 .default_headers(headers)
30 .build()?,
31 })
32 }
33
34 pub async fn starting(&self, prediction: &Prediction) -> Result<()> {
35 let request = prediction.request.clone().unwrap();
36 if !Self::should_send(&request, WebhookEvent::Start) {
37 return Ok(());
38 }
39
40 self.send(
41 request.webhook.clone().unwrap(),
42 cog_core::http::Response::starting(prediction.id.clone(), request),
43 )
44 .await?;
45
46 Ok(())
47 }
48
49 pub async fn finished(
50 &self,
51 prediction: &Prediction,
52 response: cog_core::http::Response,
53 ) -> Result<()> {
54 let request = prediction.request.clone().unwrap();
55 if !Self::should_send(&request, WebhookEvent::Completed) {
56 return Ok(());
57 }
58
59 self.send(request.webhook.clone().unwrap(), response)
60 .await?;
61
62 Ok(())
63 }
64
65 fn should_send(req: &cog_core::http::Request, event: WebhookEvent) -> bool {
66 req.webhook.is_some()
67 && req
68 .webhook_event_filters
69 .as_ref()
70 .map_or(true, |filters| filters.contains(&event))
71 }
72
73 async fn send(
74 &self,
75 url: Url,
76 res: cog_core::http::Response,
77 ) -> Result<reqwest::Response, reqwest::Error> {
78 tracing::debug!("Sending webhook to {url}");
79 tracing::trace!("{res:?}");
80
81 self.client.post(url).json(&res).send().await
82 }
83}