1use std::future::Future;
5use std::pin::Pin;
6
7use semver::Version;
8use serde::Deserialize;
9use tokio::sync::mpsc;
10
11use crate::error::SchedulerError;
12use crate::task::TaskHandler;
13
/// GitHub REST API endpoint for the latest release of `bug-ops/zeph`; the
/// default target of every update check.
const GITHUB_RELEASES_URL: &str = "https://api.github.com/repos/bug-ops/zeph/releases/latest";
/// Largest release-info response body (64 KiB) the update check will parse;
/// anything bigger is skipped instead.
const MAX_RESPONSE_BYTES: usize = 65_536;
16
17pub struct UpdateCheckHandler {
51 current_version: &'static str,
52 notify_tx: mpsc::Sender<String>,
53 http_client: reqwest::Client,
54 base_url: String,
56}
57
58#[derive(Deserialize)]
59struct ReleaseInfo {
60 tag_name: Option<String>,
61}
62
63impl UpdateCheckHandler {
64 #[must_use]
73 pub fn new(current_version: &'static str, notify_tx: mpsc::Sender<String>) -> Self {
74 let http_client = reqwest::Client::builder()
75 .timeout(std::time::Duration::from_secs(10))
76 .user_agent(format!("zeph/{current_version}"))
77 .build()
78 .expect("reqwest client builder should not fail with timeout and user_agent");
79 Self {
80 current_version,
81 notify_tx,
82 http_client,
83 base_url: GITHUB_RELEASES_URL.to_owned(),
84 }
85 }
86
87 #[must_use]
92 pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
93 self.base_url = url.into();
94 self
95 }
96
97 fn newer_version(current: &str, tag_name: &str) -> Option<String> {
99 let remote_str = tag_name.trim_start_matches('v');
100 if remote_str.is_empty() {
101 return None;
102 }
103 let current_v = Version::parse(current).ok()?;
104 let remote_v = Version::parse(remote_str).ok()?;
105 if remote_v > current_v {
106 Some(remote_str.to_owned())
107 } else {
108 None
109 }
110 }
111}
112
113impl TaskHandler for UpdateCheckHandler {
114 fn execute(
115 &self,
116 _config: &serde_json::Value,
117 ) -> Pin<Box<dyn Future<Output = Result<(), SchedulerError>> + Send + '_>> {
118 Box::pin(async move {
119 let resp = self
120 .http_client
121 .get(&self.base_url)
122 .header("Accept", "application/vnd.github+json")
123 .send()
124 .await;
125
126 let resp = match resp {
127 Ok(r) => r,
128 Err(e) => {
129 tracing::warn!("update check request failed: {e}");
130 return Ok(());
131 }
132 };
133
134 if !resp.status().is_success() {
135 tracing::warn!("update check: HTTP {}", resp.status());
136 return Ok(());
137 }
138
139 let bytes = match resp.bytes().await {
140 Ok(b) => b,
141 Err(e) => {
142 tracing::warn!("update check: failed to read response body: {e}");
143 return Ok(());
144 }
145 };
146 if bytes.len() > MAX_RESPONSE_BYTES {
147 tracing::warn!(
148 "update check: response body too large ({} bytes), skipping",
149 bytes.len()
150 );
151 return Ok(());
152 }
153 let info: ReleaseInfo = match serde_json::from_slice(&bytes) {
154 Ok(v) => v,
155 Err(e) => {
156 tracing::warn!("update check response parse failed: {e}");
157 return Ok(());
158 }
159 };
160
161 let Some(tag_name) = info.tag_name else {
162 tracing::warn!("update check: missing tag_name in response");
163 return Ok(());
164 };
165
166 match Self::newer_version(self.current_version, &tag_name) {
167 Some(remote) => {
168 let msg = format!(
169 "New version available: v{remote} (current: v{}).\nUpdate: https://github.com/bug-ops/zeph/releases/tag/v{remote}",
170 self.current_version
171 );
172 tracing::debug!("update available: {remote}");
173 let _ = self.notify_tx.send(msg).await;
174 }
175 None => {
176 tracing::debug!(
177 current = self.current_version,
178 remote = tag_name,
179 "no update available"
180 );
181 }
182 }
183
184 Ok(())
185 })
186 }
187}
188
#[cfg(test)]
mod tests {
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    use super::*;

    /// Builds a handler for `current_version` whose requests go to `server_url`.
    fn make_handler(
        current_version: &'static str,
        tx: mpsc::Sender<String>,
        server_url: &str,
    ) -> UpdateCheckHandler {
        UpdateCheckHandler::new(current_version, tx).with_base_url(server_url)
    }

    /// Shorthand for the version comparison under test.
    fn newer(current: &str, tag: &str) -> Option<String> {
        UpdateCheckHandler::newer_version(current, tag)
    }

    /// A `200 OK` response whose JSON body carries only the given `tag_name`.
    fn release_with_tag(tag: &str) -> ResponseTemplate {
        ResponseTemplate::new(200).set_body_json(serde_json::json!({ "tag_name": tag }))
    }

    /// Serves `template` for `GET /` on a fresh mock server and runs one check
    /// as version `0.11.0` against it.
    ///
    /// Returns the handler's result together with the notification receiver.
    async fn run_check(
        template: ResponseTemplate,
    ) -> (Result<(), SchedulerError>, mpsc::Receiver<String>) {
        let server = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/"))
            .respond_with(template)
            .mount(&server)
            .await;

        let (tx, rx) = mpsc::channel(1);
        let handler = make_handler("0.11.0", tx, &server.uri());
        let outcome = handler.execute(&serde_json::Value::Null).await;
        (outcome, rx)
    }

    /// Asserts that serving `template` makes the check finish with `Ok(())`
    /// without sending any notification.
    async fn assert_ok_and_silent(template: ResponseTemplate, ok_msg: &str, silent_msg: &str) {
        let (outcome, mut rx) = run_check(template).await;
        assert!(outcome.is_ok(), "{ok_msg}");
        assert!(rx.try_recv().is_err(), "{silent_msg}");
    }

    #[test]
    fn newer_version_detects_upgrade() {
        assert_eq!(newer("0.11.0", "v0.12.0"), Some("0.12.0".to_owned()));
    }

    #[test]
    fn newer_version_same_version_no_notify() {
        assert_eq!(newer("0.11.0", "v0.11.0"), None);
    }

    #[test]
    fn newer_version_older_remote_no_notify() {
        assert_eq!(newer("0.11.0", "v0.10.0"), None);
    }

    #[test]
    fn newer_version_strips_v_prefix() {
        assert_eq!(newer("1.0.0", "v2.0.0"), Some("2.0.0".to_owned()));
        assert_eq!(newer("1.0.0", "2.0.0"), Some("2.0.0".to_owned()));
    }

    #[test]
    fn newer_version_invalid_current_returns_none() {
        assert_eq!(newer("not-semver", "v1.0.0"), None);
    }

    #[test]
    fn newer_version_invalid_remote_returns_none() {
        assert_eq!(newer("1.0.0", "v-garbage"), None);
    }

    #[test]
    fn newer_version_empty_tag_returns_none() {
        assert_eq!(newer("1.0.0", ""), None);
    }

    /// A pre-release of a higher version still orders above the current
    /// release under semver, so it is reported.
    #[test]
    fn newer_version_prerelease_is_notified() {
        assert_eq!(newer("0.11.0", "v0.12.0-rc.1"), Some("0.12.0-rc.1".to_owned()));
    }

    #[tokio::test]
    async fn test_execute_newer_version_sends_notification() {
        let (outcome, mut rx) = run_check(release_with_tag("v99.0.0")).await;
        outcome.expect("handler must not return an error");

        let msg = rx.try_recv().expect("notification must be sent");
        assert!(msg.contains("99.0.0"), "notification should mention new version");
        assert!(msg.contains("0.11.0"), "notification should mention current version");
    }

    #[tokio::test]
    async fn test_execute_same_version_no_notification() {
        let (outcome, mut rx) = run_check(release_with_tag("v0.11.0")).await;
        outcome.expect("handler must not return an error");
        assert!(
            rx.try_recv().is_err(),
            "no notification expected for same version"
        );
    }

    #[tokio::test]
    async fn test_execute_http_404_no_notification_no_panic() {
        assert_ok_and_silent(
            ResponseTemplate::new(404),
            "handler must return Ok on 404",
            "no notification expected on 404",
        )
        .await;
    }

    #[tokio::test]
    async fn test_execute_http_429_rate_limit_graceful() {
        assert_ok_and_silent(
            ResponseTemplate::new(429),
            "handler must return Ok on 429",
            "no notification expected on 429",
        )
        .await;
    }

    #[tokio::test]
    async fn test_execute_http_500_server_error_graceful() {
        assert_ok_and_silent(
            ResponseTemplate::new(500),
            "handler must return Ok on 500",
            "no notification expected on 500",
        )
        .await;
    }

    #[tokio::test]
    async fn test_execute_malformed_json_graceful() {
        assert_ok_and_silent(
            ResponseTemplate::new(200).set_body_string("this is not json {{{"),
            "handler must return Ok on malformed JSON",
            "no notification expected for malformed JSON",
        )
        .await;
    }

    #[tokio::test]
    async fn test_execute_missing_tag_name_graceful() {
        let body = serde_json::json!({
            "name": "Latest Release",
            "published_at": "2024-01-01"
        });
        assert_ok_and_silent(
            ResponseTemplate::new(200).set_body_json(body),
            "handler must return Ok on missing tag_name",
            "no notification expected for missing tag_name",
        )
        .await;
    }

    #[tokio::test]
    async fn test_execute_oversized_body_graceful() {
        // One byte past the limit is enough to trip the size guard.
        let large_body = "x".repeat(MAX_RESPONSE_BYTES + 1);
        assert_ok_and_silent(
            ResponseTemplate::new(200).set_body_string(large_body),
            "handler must return Ok for oversized body",
            "no notification expected for oversized body",
        )
        .await;
    }
}