cog_rust/
webhooks.rs

1use std::env;
2
3use anyhow::Result;
4use axum::http::{HeaderMap, HeaderValue};
5use cog_core::http::WebhookEvent;
6use reqwest::Client;
7use url::Url;
8
9use crate::prediction::{Prediction, ResponseHelpers};
10
11pub struct WebhookSender {
12	client: Client,
13}
14
15impl WebhookSender {
16	pub fn new() -> Result<Self> {
17		let mut headers = HeaderMap::new();
18		let client = Client::builder();
19
20		if let Ok(token) = env::var("WEBHOOK_AUTH_TOKEN") {
21			let mut authorization = HeaderValue::from_str(&format!("Bearer {token}"))?;
22			authorization.set_sensitive(true);
23			headers.insert("Authorization", authorization);
24		}
25
26		Ok(Self {
27			client: client
28				.user_agent(format!("cog-worker/{}", env!("CARGO_PKG_VERSION")))
29				.default_headers(headers)
30				.build()?,
31		})
32	}
33
34	pub async fn starting(&self, prediction: &Prediction) -> Result<()> {
35		let request = prediction.request.clone().unwrap();
36		if !Self::should_send(&request, WebhookEvent::Start) {
37			return Ok(());
38		}
39
40		self.send(
41			request.webhook.clone().unwrap(),
42			cog_core::http::Response::starting(prediction.id.clone(), request),
43		)
44		.await?;
45
46		Ok(())
47	}
48
49	pub async fn finished(
50		&self,
51		prediction: &Prediction,
52		response: cog_core::http::Response,
53	) -> Result<()> {
54		let request = prediction.request.clone().unwrap();
55		if !Self::should_send(&request, WebhookEvent::Completed) {
56			return Ok(());
57		}
58
59		self.send(request.webhook.clone().unwrap(), response)
60			.await?;
61
62		Ok(())
63	}
64
65	fn should_send(req: &cog_core::http::Request, event: WebhookEvent) -> bool {
66		req.webhook.is_some()
67			&& req
68				.webhook_event_filters
69				.as_ref()
70				.map_or(true, |filters| filters.contains(&event))
71	}
72
73	async fn send(
74		&self,
75		url: Url,
76		res: cog_core::http::Response,
77	) -> Result<reqwest::Response, reqwest::Error> {
78		tracing::debug!("Sending webhook to {url}");
79		tracing::trace!("{res:?}");
80
81		self.client.post(url).json(&res).send().await
82	}
83}