// update_kit/config.rs

1use std::fmt;
2use std::future::Future;
3use std::pin::Pin;
4
5use serde::{Deserialize, Serialize};
6
7use crate::errors::UpdateKitError;
8use crate::types::{
9    AssetInfo, Channel, CheckMode, Confidence, DelegateMode, InstallDetection, PostAction,
10    UpdatePlan,
11};
12
13/// Package metadata (name and version).
14#[derive(Debug, Clone, Serialize, Deserialize)]
15pub struct PackageInfo {
16    pub name: String,
17    pub version: String,
18}
19
20/// Configuration for a version source.
21#[derive(Debug, Clone, Serialize, Deserialize)]
22#[serde(tag = "type", rename_all = "kebab-case")]
23pub enum VersionSourceConfig {
24    Github {
25        owner: String,
26        repo: String,
27        #[serde(skip_serializing_if = "Option::is_none")]
28        token: Option<String>,
29        #[serde(skip_serializing_if = "Option::is_none")]
30        api_base_url: Option<String>,
31    },
32    Npm {
33        package_name: String,
34        #[serde(skip_serializing_if = "Option::is_none")]
35        registry_url: Option<String>,
36    },
37    Jsr {
38        scope: String,
39        name: String,
40    },
41    Brew {
42        cask_name: String,
43    },
44    Custom {
45        url: String,
46        #[serde(skip_serializing_if = "Option::is_none")]
47        version_field: Option<String>,
48    },
49}
50
51/// Base configuration with optional fields grouped by pipeline stage.
52#[derive(Debug, Clone, Default, Serialize, Deserialize)]
53pub struct BaseConfig {
54    // Detection
55    #[serde(skip_serializing_if = "Option::is_none")]
56    pub repository: Option<String>,
57    #[serde(skip_serializing_if = "Option::is_none")]
58    pub brew_cask_name: Option<String>,
59    #[serde(skip_serializing_if = "Option::is_none")]
60    pub npm_package_name: Option<String>,
61
62    // Check
63    #[serde(skip_serializing_if = "Option::is_none")]
64    pub sources: Option<Vec<VersionSourceConfig>>,
65    #[serde(skip_serializing_if = "Option::is_none")]
66    pub check_interval_ms: Option<u64>,
67    #[serde(skip_serializing_if = "Option::is_none")]
68    pub check_mode: Option<CheckMode>,
69
70    // Plan
71    #[serde(skip_serializing_if = "Option::is_none")]
72    pub delegate_mode: Option<DelegateMode>,
73    #[serde(skip_serializing_if = "Option::is_none")]
74    pub post_action: Option<PostAction>,
75    #[serde(skip_serializing_if = "Option::is_none")]
76    pub min_confidence: Option<Confidence>,
77
78    // Plan (continued)
79    #[serde(skip_serializing_if = "Option::is_none")]
80    pub allow_reexec: Option<bool>,
81    #[serde(skip_serializing_if = "Option::is_none")]
82    pub asset_pattern: Option<String>,
83
84    // Apply
85    #[serde(skip_serializing_if = "Option::is_none")]
86    pub download_timeout_ms: Option<u64>,
87    #[serde(skip_serializing_if = "Option::is_none")]
88    pub delegate_timeout_ms: Option<u64>,
89    #[serde(skip_serializing_if = "Option::is_none")]
90    pub require_checksum: Option<bool>,
91}
92
93/// Top-level configuration for UpdateKit.
94#[derive(Debug, Clone, Serialize, Deserialize)]
95#[serde(untagged)]
96pub enum UpdateKitConfig {
97    Explicit {
98        app_name: String,
99        current_version: String,
100        #[serde(flatten)]
101        base: BaseConfig,
102    },
103    Pkg {
104        #[serde(skip_serializing_if = "Option::is_none")]
105        app_name: Option<String>,
106        #[serde(skip_serializing_if = "Option::is_none")]
107        current_version: Option<String>,
108        pkg: PackageInfo,
109        #[serde(flatten)]
110        base: BaseConfig,
111    },
112}
113
114/// Async hook that takes no arguments and returns `Result<(), UpdateKitError>`.
115pub type BeforeCheckHook =
116    Box<dyn Fn() -> Pin<Box<dyn Future<Output = Result<(), UpdateKitError>>>> + Send + Sync>;
117
118/// Async hook called before applying an update. Returns `Ok(true)` to proceed, `Ok(false)` to abort.
119pub type BeforeApplyHook = Box<
120    dyn Fn(&UpdatePlan) -> Pin<Box<dyn Future<Output = Result<bool, UpdateKitError>> + Send>>
121        + Send
122        + Sync,
123>;
124
125/// Async hook called after applying an update.
126pub type AfterApplyHook = Box<
127    dyn Fn(
128            &crate::types::ApplyResult,
129        ) -> Pin<Box<dyn Future<Output = Result<(), UpdateKitError>> + Send>>
130        + Send
131        + Sync,
132>;
133
134/// Synchronous error handler hook.
135pub type OnErrorHook = Box<dyn Fn(&UpdateKitError) + Send + Sync>;
136
137/// Lifecycle hooks for the update pipeline.
138#[derive(Default)]
139pub struct Hooks {
140    /// Called before checking for updates.
141    pub before_check: Option<BeforeCheckHook>,
142    /// Called before applying an update. Return `false` to abort.
143    pub before_apply: Option<BeforeApplyHook>,
144    /// Called after applying an update.
145    pub after_apply: Option<AfterApplyHook>,
146    /// Called on any error during the pipeline.
147    pub on_error: Option<OnErrorHook>,
148}
149
150impl fmt::Debug for Hooks {
151    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
152        f.debug_struct("Hooks")
153            .field("before_check", &self.before_check.as_ref().map(|_| "..."))
154            .field("before_apply", &self.before_apply.as_ref().map(|_| "..."))
155            .field("after_apply", &self.after_apply.as_ref().map(|_| "..."))
156            .field("on_error", &self.on_error.as_ref().map(|_| "..."))
157            .finish()
158    }
159}
160
161/// Async function type for custom detection.
162pub type DetectFn = Box<
163    dyn Fn() -> Pin<Box<dyn Future<Output = Result<Option<InstallDetection>, UpdateKitError>> + Send>>
164        + Send
165        + Sync,
166>;
167
168/// A custom detector that can be plugged into the detection pipeline.
169pub struct CustomDetector {
170    /// Name identifying this detector.
171    pub name: String,
172    /// The detection function.
173    pub detect: DetectFn,
174}
175
176impl fmt::Debug for CustomDetector {
177    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
178        f.debug_struct("CustomDetector")
179            .field("name", &self.name)
180            .field("detect", &"...")
181            .finish()
182    }
183}
184
185/// Context passed to a custom plan resolver.
186pub struct PlanResolverContext<'a> {
187    pub channel: &'a Channel,
188    pub confidence: &'a Confidence,
189    pub to_version: &'a str,
190    pub config: &'a BaseConfig,
191    pub assets: &'a Option<Vec<AssetInfo>>,
192    pub default_plan: &'a UpdatePlan,
193}
194
195/// Type alias for a custom plan resolver function.
196pub type CustomPlanResolver = Box<
197    dyn Fn(PlanResolverContext<'_>) -> Pin<Box<dyn Future<Output = Result<UpdatePlan, UpdateKitError>> + Send>>
198        + Send
199        + Sync,
200>;
201
202/// Fully resolved and validated configuration.
203#[derive(Debug, Clone)]
204pub struct ResolvedConfig {
205    pub app_name: String,
206    pub current_version: semver::Version,
207    pub base: BaseConfig,
208}
209
210impl TryFrom<UpdateKitConfig> for ResolvedConfig {
211    type Error = UpdateKitError;
212
213    fn try_from(config: UpdateKitConfig) -> Result<Self, Self::Error> {
214        let (app_name, version_str, base) = match config {
215            UpdateKitConfig::Explicit {
216                app_name,
217                current_version,
218                base,
219            } => (app_name, current_version, base),
220            UpdateKitConfig::Pkg {
221                app_name,
222                current_version,
223                pkg,
224                base,
225            } => {
226                let name = app_name.unwrap_or(pkg.name);
227                let version = current_version.unwrap_or(pkg.version);
228                (name, version, base)
229            }
230        };
231
232        if app_name.is_empty() {
233            return Err(UpdateKitError::VersionParse(
234                "app_name must not be empty".into(),
235            ));
236        }
237
238        if version_str.is_empty() {
239            return Err(UpdateKitError::VersionParse(
240                "current_version must not be empty".into(),
241            ));
242        }
243
244        let current_version = semver::Version::parse(&version_str).map_err(|e| {
245            UpdateKitError::VersionParse(format!(
246                "invalid semver '{}': {}",
247                version_str, e
248            ))
249        })?;
250
251        Ok(ResolvedConfig {
252            app_name,
253            current_version,
254            base,
255        })
256    }
257}
258
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_package_info_creation() {
        let pkg = PackageInfo {
            name: "my-app".into(),
            version: "1.0.0".into(),
        };
        assert_eq!(pkg.name, "my-app");
        assert_eq!(pkg.version, "1.0.0");
    }

    #[test]
    fn test_base_config_default() {
        let config = BaseConfig::default();
        assert!(config.repository.is_none());
        assert!(config.sources.is_none());
        assert!(config.check_mode.is_none());
        assert!(config.delegate_mode.is_none());
        assert!(config.require_checksum.is_none());
    }

    #[test]
    fn test_resolved_config_explicit() {
        let config = UpdateKitConfig::Explicit {
            app_name: "my-app".into(),
            current_version: "1.2.3".into(),
            base: BaseConfig::default(),
        };
        let resolved = ResolvedConfig::try_from(config).unwrap();
        assert_eq!(resolved.app_name, "my-app");
        assert_eq!(resolved.current_version, semver::Version::new(1, 2, 3));
    }

    #[test]
    fn test_resolved_config_pkg_fallback() {
        let config = UpdateKitConfig::Pkg {
            app_name: None,
            current_version: None,
            pkg: PackageInfo {
                name: "pkg-app".into(),
                version: "0.5.0".into(),
            },
            base: BaseConfig::default(),
        };
        let resolved = ResolvedConfig::try_from(config).unwrap();
        assert_eq!(resolved.app_name, "pkg-app");
        assert_eq!(resolved.current_version, semver::Version::new(0, 5, 0));
    }

    #[test]
    fn test_resolved_config_pkg_override() {
        let config = UpdateKitConfig::Pkg {
            app_name: Some("override-name".into()),
            current_version: Some("3.0.0".into()),
            pkg: PackageInfo {
                name: "pkg-app".into(),
                version: "0.5.0".into(),
            },
            base: BaseConfig::default(),
        };
        let resolved = ResolvedConfig::try_from(config).unwrap();
        assert_eq!(resolved.app_name, "override-name");
        assert_eq!(resolved.current_version, semver::Version::new(3, 0, 0));
    }

    #[test]
    fn test_resolved_config_empty_app_name() {
        let config = UpdateKitConfig::Explicit {
            app_name: "".into(),
            current_version: "1.0.0".into(),
            base: BaseConfig::default(),
        };
        let err = ResolvedConfig::try_from(config).unwrap_err();
        assert_eq!(err.code(), "VERSION_PARSE");
        assert!(err.to_string().contains("app_name"));
    }

    #[test]
    fn test_resolved_config_empty_version() {
        let config = UpdateKitConfig::Explicit {
            app_name: "my-app".into(),
            current_version: "".into(),
            base: BaseConfig::default(),
        };
        let err = ResolvedConfig::try_from(config).unwrap_err();
        assert_eq!(err.code(), "VERSION_PARSE");
        assert!(err.to_string().contains("current_version"));
    }

    #[test]
    fn test_resolved_config_invalid_semver() {
        let config = UpdateKitConfig::Explicit {
            app_name: "my-app".into(),
            current_version: "not-a-version".into(),
            base: BaseConfig::default(),
        };
        let err = ResolvedConfig::try_from(config).unwrap_err();
        assert_eq!(err.code(), "VERSION_PARSE");
        assert!(err.to_string().contains("invalid semver"));
    }

    #[test]
    fn test_update_kit_config_deserialize_explicit() {
        let json = r#"{ "app_name": "my-app", "current_version": "1.2.3" }"#;
        let config: UpdateKitConfig = serde_json::from_str(json).unwrap();
        match &config {
            UpdateKitConfig::Explicit {
                app_name,
                current_version,
                ..
            } => {
                assert_eq!(app_name, "my-app");
                assert_eq!(current_version, "1.2.3");
            }
            UpdateKitConfig::Pkg { .. } => panic!("expected Explicit variant"),
        }
    }

    #[test]
    fn test_update_kit_config_deserialize_pkg_with_flattened_base() {
        // No `app_name` / `current_version`, so the untagged enum must fall
        // through `Explicit` to `Pkg`; base fields sit at the top level.
        let json = r#"{
            "pkg": { "name": "pkg-app", "version": "0.5.0" },
            "repository": "user/project",
            "check_interval_ms": 5000
        }"#;
        let config: UpdateKitConfig = serde_json::from_str(json).unwrap();
        match &config {
            UpdateKitConfig::Pkg {
                app_name,
                current_version,
                pkg,
                base,
            } => {
                assert!(app_name.is_none());
                assert!(current_version.is_none());
                assert_eq!(pkg.name, "pkg-app");
                assert_eq!(pkg.version, "0.5.0");
                assert_eq!(base.repository.as_deref(), Some("user/project"));
                assert_eq!(base.check_interval_ms, Some(5000));
            }
            UpdateKitConfig::Explicit { .. } => panic!("expected Pkg variant"),
        }
        let resolved = ResolvedConfig::try_from(config).unwrap();
        assert_eq!(resolved.app_name, "pkg-app");
        assert_eq!(resolved.current_version, semver::Version::new(0, 5, 0));
    }

    #[test]
    fn test_version_source_config_github_serialization() {
        let source = VersionSourceConfig::Github {
            owner: "user".into(),
            repo: "project".into(),
            token: None,
            api_base_url: None,
        };
        let json = serde_json::to_value(&source).unwrap();
        assert_eq!(json["type"], "github");
        assert_eq!(json["owner"], "user");
        assert_eq!(json["repo"], "project");
        // `None` optionals must be skipped, not serialized as `null`.
        assert!(json.get("token").is_none());
        assert!(json.get("api_base_url").is_none());
    }

    #[test]
    fn test_version_source_config_npm_serialization() {
        let source = VersionSourceConfig::Npm {
            package_name: "@scope/pkg".into(),
            registry_url: Some("https://registry.npmjs.org".into()),
        };
        let json = serde_json::to_value(&source).unwrap();
        assert_eq!(json["type"], "npm");
        assert_eq!(json["package_name"], "@scope/pkg");
        assert_eq!(json["registry_url"], "https://registry.npmjs.org");
    }

    #[test]
    fn test_version_source_config_brew_serialization() {
        let source = VersionSourceConfig::Brew {
            cask_name: "my-app".into(),
        };
        let json = serde_json::to_value(&source).unwrap();
        assert_eq!(json["type"], "brew");
        assert_eq!(json["cask_name"], "my-app");
    }

    #[test]
    fn test_version_source_config_jsr_deserialization() {
        let json = r#"{ "type": "jsr", "scope": "std", "name": "fs" }"#;
        let source: VersionSourceConfig = serde_json::from_str(json).unwrap();
        match source {
            VersionSourceConfig::Jsr { scope, name } => {
                assert_eq!(scope, "std");
                assert_eq!(name, "fs");
            }
            _ => panic!("expected Jsr variant"),
        }
    }

    #[test]
    fn test_version_source_config_roundtrip() {
        let source = VersionSourceConfig::Custom {
            url: "https://api.example.com/version".into(),
            version_field: Some("latest".into()),
        };
        let json = serde_json::to_string(&source).unwrap();
        let deserialized: VersionSourceConfig = serde_json::from_str(&json).unwrap();
        match deserialized {
            VersionSourceConfig::Custom { url, version_field } => {
                assert_eq!(url, "https://api.example.com/version");
                assert_eq!(version_field.as_deref(), Some("latest"));
            }
            _ => panic!("expected Custom variant"),
        }
    }

    #[test]
    fn test_hooks_debug() {
        let hooks = Hooks::default();
        let debug = format!("{:?}", hooks);
        assert!(debug.contains("Hooks"));
        // Unset hooks must render as `None` rather than being omitted.
        assert!(debug.contains("before_check: None"));
        assert!(debug.contains("before_apply: None"));
        assert!(debug.contains("after_apply: None"));
        assert!(debug.contains("on_error: None"));
    }

    #[test]
    fn test_custom_detector_debug() {
        let detector = CustomDetector {
            name: "test-detector".into(),
            detect: Box::new(|| Box::pin(async { Ok(None) })),
        };
        let debug = format!("{:?}", detector);
        assert!(debug.contains("test-detector"));
        assert!(debug.contains("detect: \"...\""));
    }
}