Skip to main content

mcpkit_rs/bundle/
mod.rs

1//! Bundle distribution system for mcpkit-rs
2//!
3//! This module provides functionality for distributing WASM bundles via OCI registries.
4
5use std::path::Path;
6
7use sha2::{Digest, Sha256};
8
9pub mod cache;
10pub mod oci;
11
12pub use cache::BundleCache;
13pub use oci::{BundleClient, OciError};
14
15use crate::ErrorData;
16
17/// A bundle consisting of WASM module and configuration
18#[derive(Debug, Clone)]
19pub struct Bundle {
20    /// The WASM module bytes
21    pub wasm: Vec<u8>,
22
23    /// The configuration YAML bytes
24    pub config: Vec<u8>,
25
26    /// Bundle metadata
27    pub metadata: BundleMetadata,
28}
29
/// Descriptive metadata that accompanies a bundle's module and config bytes.
#[derive(Debug, Clone)]
pub struct BundleMetadata {
    /// Registry URI the bundle is associated with; empty when the bundle was
    /// loaded from a directory that has no `metadata.json`.
    pub registry: String,

    /// Bundle version; empty when no `metadata.json` was available.
    pub version: String,

    /// SHA-256 digest of the WASM module, in the `sha256:<hex>` form.
    pub wasm_digest: String,

    /// SHA-256 digest of the configuration, in the `sha256:<hex>` form.
    pub config_digest: String,

    /// Time the bundle was pulled (or created, for locally built bundles).
    /// Persisted as whole seconds since the Unix epoch.
    pub pulled_at: std::time::SystemTime,
}
48
49/// Bundle distribution errors
50#[derive(Debug, thiserror::Error)]
51pub enum BundleError {
52    #[error("OCI operation failed: {0}")]
53    OciError(#[from] OciError),
54
55    #[error("Cache operation failed: {0}")]
56    CacheError(#[from] cache::CacheError),
57
58    #[error("Digest mismatch - expected: {expected}, computed: {computed}")]
59    DigestMismatch { expected: String, computed: String },
60
61    #[error("Bundle not found: {0}")]
62    NotFound(String),
63
64    #[error("Invalid URI: {0}")]
65    InvalidUri(String),
66
67    #[error("Authentication failed: {0}")]
68    AuthenticationFailed(String),
69
70    #[error("IO error: {0}")]
71    IoError(#[from] std::io::Error),
72
73    #[error("Configuration error: {0}")]
74    ConfigError(String),
75}
76
77impl From<BundleError> for ErrorData {
78    fn from(err: BundleError) -> Self {
79        match err {
80            BundleError::NotFound(name) => {
81                ErrorData::invalid_request(format!("Bundle not found: {}", name), None)
82            }
83            BundleError::AuthenticationFailed(msg) => {
84                ErrorData::invalid_request(format!("Authentication failed: {}", msg), None)
85            }
86            _ => ErrorData::internal_error(err.to_string(), None),
87        }
88    }
89}
90
91/// Compute SHA256 digest of content
92pub fn compute_digest(content: &[u8]) -> String {
93    let mut hasher = Sha256::new();
94    hasher.update(content);
95    format!("sha256:{}", hex::encode(hasher.finalize()))
96}
97
98/// Verify content against expected digest
99pub fn verify_digest(content: &[u8], expected: &str) -> Result<(), BundleError> {
100    let computed = compute_digest(content);
101    if computed != expected {
102        return Err(BundleError::DigestMismatch {
103            expected: expected.to_string(),
104            computed,
105        });
106    }
107    Ok(())
108}
109
110/// Parse OCI URI into registry, repository, and tag
111pub fn parse_oci_uri(uri: &str) -> Result<(String, String, Option<String>), BundleError> {
112    if !uri.starts_with("oci://") {
113        return Err(BundleError::InvalidUri(format!(
114            "URI must start with 'oci://': {}",
115            uri
116        )));
117    }
118
119    let uri = &uri[6..]; // Remove "oci://" prefix
120
121    // Split into registry and path
122    let parts: Vec<&str> = uri.splitn(2, '/').collect();
123    if parts.len() != 2 {
124        return Err(BundleError::InvalidUri(format!(
125            "Invalid OCI URI format: {}",
126            uri
127        )));
128    }
129
130    let registry = parts[0];
131    let path_and_tag = parts[1];
132
133    // Split repository and tag/digest
134    let (repository, tag) = if let Some(at_pos) = path_and_tag.rfind('@') {
135        // Digest reference (e.g., @sha256:abc123)
136        let repo = &path_and_tag[..at_pos];
137        let digest = &path_and_tag[at_pos + 1..];
138        (repo.to_string(), Some(digest.to_string()))
139    } else if let Some(colon_pos) = path_and_tag.rfind(':') {
140        // Tag reference (e.g., :v1.0.0)
141        let repo = &path_and_tag[..colon_pos];
142        let tag = &path_and_tag[colon_pos + 1..];
143        (repo.to_string(), Some(tag.to_string()))
144    } else {
145        // No tag or digest
146        (path_and_tag.to_string(), None)
147    };
148
149    Ok((registry.to_string(), repository, tag))
150}
151
152impl Bundle {
153    /// Create a new bundle from WASM and config bytes
154    pub fn new(wasm: Vec<u8>, config: Vec<u8>, registry: String, version: String) -> Self {
155        let wasm_digest = compute_digest(&wasm);
156        let config_digest = compute_digest(&config);
157
158        Self {
159            wasm,
160            config,
161            metadata: BundleMetadata {
162                registry,
163                version,
164                wasm_digest,
165                config_digest,
166                pulled_at: std::time::SystemTime::now(),
167            },
168        }
169    }
170
171    /// Load bundle from filesystem
172    pub fn from_directory(path: &Path) -> Result<Self, BundleError> {
173        let wasm_path = path.join("module.wasm");
174        let config_path = path.join("config.yaml");
175        let metadata_path = path.join("metadata.json");
176
177        if !wasm_path.exists() || !config_path.exists() {
178            return Err(BundleError::NotFound(path.display().to_string()));
179        }
180
181        let wasm = std::fs::read(&wasm_path)?;
182        let config = std::fs::read(&config_path)?;
183
184        // Load metadata if it exists
185        let metadata = if metadata_path.exists() {
186            let metadata_str = std::fs::read_to_string(&metadata_path)?;
187            serde_json::from_str(&metadata_str)
188                .map_err(|e| BundleError::ConfigError(e.to_string()))?
189        } else {
190            // Create default metadata
191            BundleMetadata {
192                registry: String::new(),
193                version: String::new(),
194                wasm_digest: compute_digest(&wasm),
195                config_digest: compute_digest(&config),
196                pulled_at: std::time::SystemTime::now(),
197            }
198        };
199
200        Ok(Self {
201            wasm,
202            config,
203            metadata,
204        })
205    }
206
207    /// Save bundle to filesystem
208    pub fn save_to_directory(&self, path: &Path) -> Result<(), BundleError> {
209        std::fs::create_dir_all(path)?;
210
211        std::fs::write(path.join("module.wasm"), &self.wasm)?;
212        std::fs::write(path.join("config.yaml"), &self.config)?;
213
214        let metadata_json = serde_json::to_string_pretty(&self.metadata)
215            .map_err(|e| BundleError::ConfigError(e.to_string()))?;
216        std::fs::write(path.join("metadata.json"), metadata_json)?;
217
218        Ok(())
219    }
220
221    /// Verify bundle integrity
222    pub fn verify(&self) -> Result<(), BundleError> {
223        verify_digest(&self.wasm, &self.metadata.wasm_digest)?;
224        verify_digest(&self.config, &self.metadata.config_digest)?;
225        Ok(())
226    }
227}
228
229// Implement Serialize for BundleMetadata so it can be saved
230impl serde::Serialize for BundleMetadata {
231    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
232    where
233        S: serde::Serializer,
234    {
235        use serde::ser::SerializeStruct;
236
237        let mut state = serializer.serialize_struct("BundleMetadata", 5)?;
238        state.serialize_field("registry", &self.registry)?;
239        state.serialize_field("version", &self.version)?;
240        state.serialize_field("wasm_digest", &self.wasm_digest)?;
241        state.serialize_field("config_digest", &self.config_digest)?;
242
243        // Serialize SystemTime as ISO8601 string
244        let duration = self
245            .pulled_at
246            .duration_since(std::time::UNIX_EPOCH)
247            .unwrap_or_default();
248        let timestamp = duration.as_secs();
249        state.serialize_field("pulled_at", &timestamp)?;
250
251        state.end()
252    }
253}
254
255// Implement Deserialize for BundleMetadata
256impl<'de> serde::Deserialize<'de> for BundleMetadata {
257    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
258    where
259        D: serde::Deserializer<'de>,
260    {
261        #[derive(serde::Deserialize)]
262        struct Helper {
263            registry: String,
264            version: String,
265            wasm_digest: String,
266            config_digest: String,
267            pulled_at: u64,
268        }
269
270        let helper = Helper::deserialize(deserializer)?;
271
272        Ok(BundleMetadata {
273            registry: helper.registry,
274            version: helper.version,
275            wasm_digest: helper.wasm_digest,
276            config_digest: helper.config_digest,
277            pulled_at: std::time::UNIX_EPOCH + std::time::Duration::from_secs(helper.pulled_at),
278        })
279    }
280}
281
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_oci_uri() {
        // (uri, registry, repository, reference)
        let accepted = [
            ("oci://ghcr.io/org/tool:v1.0.0", "ghcr.io", "org/tool", Some("v1.0.0")),
            (
                "oci://docker.io/org/tool@sha256:abc123",
                "docker.io",
                "org/tool",
                Some("sha256:abc123"),
            ),
            ("oci://ghcr.io/org/tool", "ghcr.io", "org/tool", None),
        ];

        for (uri, registry, repository, reference) in accepted {
            let parsed = parse_oci_uri(uri).unwrap();
            assert_eq!(parsed.0, registry, "registry of {}", uri);
            assert_eq!(parsed.1, repository, "repository of {}", uri);
            assert_eq!(parsed.2.as_deref(), reference, "reference of {}", uri);
        }

        // Only the oci:// scheme is understood.
        assert!(parse_oci_uri("https://ghcr.io/org/tool").is_err());
    }

    #[test]
    fn test_compute_digest() {
        let digest = compute_digest(b"test content");
        let encoded = digest
            .strip_prefix("sha256:")
            .expect("digest must carry the sha256: prefix");
        // SHA-256 is 32 bytes, i.e. 64 hex characters.
        assert_eq!(encoded.len(), 64);
    }

    #[test]
    fn test_verify_digest() {
        let digest = compute_digest(b"test content");

        assert!(verify_digest(b"test content", &digest).is_ok());
        assert!(verify_digest(b"different content", &digest).is_err());
    }

    #[test]
    fn test_bundle_creation() {
        // "\0asm" is the WASM magic number: 0x00 0x61 0x73 0x6d.
        let wasm = b"\0asm".to_vec();
        let config = b"version: 1.0".to_vec();

        let bundle = Bundle::new(
            wasm.clone(),
            config.clone(),
            "ghcr.io/test/bundle".into(),
            "1.0.0".into(),
        );

        assert_eq!(bundle.wasm, wasm);
        assert_eq!(bundle.config, config);
        assert_eq!(bundle.metadata.registry, "ghcr.io/test/bundle");
        assert_eq!(bundle.metadata.version, "1.0.0");
        bundle.verify().expect("a freshly built bundle must verify");
    }
}