//! LakeFS storage backend for Briefcase (`briefcase_core/storage/lakefs.rs`).
//!
//! Buffers snapshot writes locally and persists them to a LakeFS
//! repository branch on `flush()`, recording each batch as a commit.
1use super::{FlushResult, SnapshotQuery, StorageBackend, StorageError};
2use crate::models::{DecisionSnapshot, Snapshot};
3#[cfg(feature = "networking")]
4use base64::{engine::general_purpose, Engine as _};
5#[cfg(feature = "networking")]
6use reqwest::{
7    header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE},
8    Client,
9};
10use serde::{Deserialize, Serialize};
11use std::sync::{Arc, Mutex};
12
/// Storage backend that persists Briefcase snapshots to a LakeFS
/// repository over its HTTP API.
///
/// Writes are buffered in `pending_writes` by `save()` and only uploaded
/// (and committed) when `flush()` is called. Cloning the backend is cheap:
/// clones share the same pending-write queue via the `Arc`.
#[cfg(all(feature = "async", feature = "networking"))]
#[derive(Clone)]
pub struct LakeFSBackend {
    // HTTP client pre-configured with Basic-auth and JSON default headers.
    client: Client,
    // Base API endpoint, stored without a trailing slash (normalized in `new`).
    endpoint: String,
    repository: String, // "engagement" in Briefcase terminology
    branch: String,     // "workstream" in Briefcase terminology
    // Credentials are baked into the client's default headers at construction;
    // the raw values are retained but not read afterwards.
    #[allow(dead_code)]
    access_key: String,
    #[allow(dead_code)]
    secret_key: String,
    // Snapshots queued by `save()` until the next `flush()`.
    pending_writes: Arc<Mutex<Vec<Snapshot>>>,
}
26
#[cfg(all(feature = "async", feature = "networking"))]
impl LakeFSBackend {
    /// Build a backend from `config`.
    ///
    /// Constructs a `reqwest` client whose default headers carry HTTP Basic
    /// auth (base64 of `access_key:secret_key`) and a JSON content type,
    /// with a 30-second per-request timeout.
    ///
    /// # Errors
    /// Returns `StorageError::ConnectionError` if the auth header value is
    /// invalid or the HTTP client cannot be built.
    pub fn new(config: LakeFSConfig) -> Result<Self, StorageError> {
        let mut headers = HeaderMap::new();

        // Create basic auth header
        let credentials = format!("{}:{}", config.access_key, config.secret_key);
        let encoded = general_purpose::STANDARD.encode(credentials.as_bytes());
        let auth_header = format!("Basic {}", encoded);

        headers.insert(
            AUTHORIZATION,
            HeaderValue::from_str(&auth_header).map_err(|e| {
                StorageError::ConnectionError(format!("Invalid auth header: {}", e))
            })?,
        );

        headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));

        let client = Client::builder()
            .default_headers(headers)
            .timeout(std::time::Duration::from_secs(30))
            .build()
            .map_err(|e| {
                StorageError::ConnectionError(format!("Failed to create HTTP client: {}", e))
            })?;

        Ok(Self {
            client,
            // Normalize so the URL formatting below never produces "//".
            endpoint: config.endpoint.trim_end_matches('/').to_string(),
            repository: config.repository,
            branch: config.branch,
            access_key: config.access_key,
            secret_key: config.secret_key,
            pending_writes: Arc::new(Mutex::new(Vec::new())),
        })
    }

    /// Create a commit (checkpoint) with pending writes
    ///
    /// POSTs to the branch's `/commits` endpoint with `message` and a
    /// `source: briefcase-ai` metadata tag; returns the new commit id.
    ///
    /// # Errors
    /// `ConnectionError` on transport failure or non-success status;
    /// `SerializationError` if the response body cannot be parsed.
    async fn create_commit(&self, message: &str) -> Result<String, StorageError> {
        let url = format!(
            "{}/repositories/{}/branches/{}/commits",
            self.endpoint, self.repository, self.branch
        );

        // Minimal request body; defined locally since it is used nowhere else.
        #[derive(Serialize)]
        struct CommitRequest {
            message: String,
            metadata: std::collections::HashMap<String, String>,
        }

        let mut metadata = std::collections::HashMap::new();
        metadata.insert("source".to_string(), "briefcase-ai".to_string());

        let request = CommitRequest {
            message: message.to_string(),
            metadata,
        };

        let response = self
            .client
            .post(&url)
            .json(&request)
            .send()
            .await
            .map_err(|e| {
                StorageError::ConnectionError(format!("Failed to create commit: {}", e))
            })?;

        let status = response.status();
        if !status.is_success() {
            // Best-effort error body; an unreadable body becomes "".
            let error_text = response.text().await.unwrap_or_default();
            return Err(StorageError::ConnectionError(format!(
                "Commit failed with status {}: {}",
                status, error_text
            )));
        }

        // Only the commit id is needed from the response payload.
        #[derive(Deserialize)]
        struct CommitResponse {
            id: String,
        }

        let commit_response: CommitResponse = response.json().await.map_err(|e| {
            StorageError::SerializationError(format!("Failed to parse commit response: {}", e))
        })?;

        Ok(commit_response.id)
    }

    /// Upload object to LakeFS
    ///
    /// PUTs `data` to the branch's `/objects` endpoint, addressing the
    /// object via the `path` query parameter. The per-request content type
    /// overrides the client's JSON default.
    async fn upload_object(&self, path: &str, data: &[u8]) -> Result<(), StorageError> {
        let url = format!(
            "{}/repositories/{}/branches/{}/objects",
            self.endpoint, self.repository, self.branch
        );

        let response = self
            .client
            .put(&url)
            .query(&[("path", path)])
            .header("Content-Type", "application/octet-stream")
            .body(data.to_vec())
            .send()
            .await
            .map_err(|e| {
                StorageError::ConnectionError(format!("Failed to upload object: {}", e))
            })?;

        let status = response.status();
        if !status.is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(StorageError::ConnectionError(format!(
                "Upload failed with status {}: {}",
                status, error_text
            )));
        }

        Ok(())
    }

    /// Download object from LakeFS
    ///
    /// Reads from the `refs` endpoint using the configured branch name as
    /// the ref. A 404 maps to `StorageError::NotFound`; other failures map
    /// to `ConnectionError`.
    async fn download_object(&self, path: &str) -> Result<Vec<u8>, StorageError> {
        let url = format!(
            "{}/repositories/{}/refs/{}/objects",
            self.endpoint, self.repository, self.branch
        );

        let response = self
            .client
            .get(&url)
            .query(&[("path", path)])
            .send()
            .await
            .map_err(|e| {
                StorageError::ConnectionError(format!("Failed to download object: {}", e))
            })?;

        // Distinguish "missing object" from other HTTP failures.
        if response.status() == reqwest::StatusCode::NOT_FOUND {
            return Err(StorageError::NotFound(format!(
                "Object not found: {}",
                path
            )));
        }

        let status = response.status();
        if !status.is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(StorageError::ConnectionError(format!(
                "Download failed with status {}: {}",
                status, error_text
            )));
        }

        let data = response.bytes().await.map_err(|e| {
            StorageError::ConnectionError(format!("Failed to read response: {}", e))
        })?;

        Ok(data.to_vec())
    }

    /// List objects with prefix
    ///
    /// Calls the `objects/ls` endpoint and returns the paths of entries
    /// whose type is "object" (directories and other entry types are
    /// filtered out). NOTE(review): results are taken from a single page;
    /// pagination of large listings is not handled here — confirm against
    /// expected repository sizes.
    async fn list_objects(&self, prefix: &str) -> Result<Vec<String>, StorageError> {
        let url = format!(
            "{}/repositories/{}/refs/{}/objects/ls",
            self.endpoint, self.repository, self.branch
        );

        let response = self
            .client
            .get(&url)
            .query(&[("prefix", prefix)])
            .send()
            .await
            .map_err(|e| StorageError::ConnectionError(format!("Failed to list objects: {}", e)))?;

        let status = response.status();
        if !status.is_success() {
            let error_text = response.text().await.unwrap_or_default();
            return Err(StorageError::ConnectionError(format!(
                "List failed with status {}: {}",
                status, error_text
            )));
        }

        // Local mirror of the fields we need from the listing payload.
        #[derive(Deserialize)]
        struct ListResponse {
            results: Vec<ObjectInfo>,
        }

        #[derive(Deserialize)]
        struct ObjectInfo {
            path: String,
            // "type" is a Rust keyword, hence the rename.
            #[serde(rename = "type")]
            object_type: String,
        }

        let list_response: ListResponse = response.json().await.map_err(|e| {
            StorageError::SerializationError(format!("Failed to parse list response: {}", e))
        })?;

        let paths = list_response
            .results
            .into_iter()
            .filter(|obj| obj.object_type == "object")
            .map(|obj| obj.path)
            .collect();

        Ok(paths)
    }

    /// Generate object path for snapshot
    fn snapshot_path(&self, snapshot_id: &str) -> String {
        format!("snapshots/{}.json", snapshot_id)
    }

    /// Generate object path for decision
    fn decision_path(&self, decision_id: &str) -> String {
        format!("decisions/{}.json", decision_id)
    }
}
248
#[cfg(all(feature = "async", feature = "networking"))]
#[async_trait::async_trait]
impl StorageBackend for LakeFSBackend {
    /// Queue `snapshot` for upload; nothing touches the network until
    /// `flush()`. Returns the snapshot's id.
    async fn save(&self, snapshot: &Snapshot) -> Result<String, StorageError> {
        let snapshot_id = snapshot.metadata.snapshot_id.to_string();

        // Add to pending writes instead of immediate upload.
        // Scoped so the mutex guard is dropped before returning.
        {
            let mut pending = self.pending_writes.lock().unwrap();
            pending.push(snapshot.clone());
        }

        Ok(snapshot_id)
    }

    /// Upload a single decision immediately (decisions are not batched,
    /// unlike snapshots). Returns the decision's id.
    async fn save_decision(&self, decision: &DecisionSnapshot) -> Result<String, StorageError> {
        let decision_id = decision.metadata.snapshot_id.to_string();
        let path = self.decision_path(&decision_id);

        let json_data = serde_json::to_vec(decision)
            .map_err(|e| StorageError::SerializationError(e.to_string()))?;

        self.upload_object(&path, &json_data).await?;

        Ok(decision_id)
    }

    /// Download and deserialize the snapshot stored under `snapshot_id`.
    ///
    /// # Errors
    /// `NotFound` if no such object exists; `SerializationError` if the
    /// stored JSON does not parse as a `Snapshot`.
    async fn load(&self, snapshot_id: &str) -> Result<Snapshot, StorageError> {
        let path = self.snapshot_path(snapshot_id);
        let data = self.download_object(&path).await?;

        let snapshot: Snapshot = serde_json::from_slice(&data)
            .map_err(|e| StorageError::SerializationError(e.to_string()))?;

        Ok(snapshot)
    }

    /// Download and deserialize the decision stored under `decision_id`.
    async fn load_decision(&self, decision_id: &str) -> Result<DecisionSnapshot, StorageError> {
        let path = self.decision_path(decision_id);
        let data = self.download_object(&path).await?;

        let decision: DecisionSnapshot = serde_json::from_slice(&data)
            .map_err(|e| StorageError::SerializationError(e.to_string()))?;

        Ok(decision)
    }

    /// List, load, and filter stored snapshots against `query`, applying
    /// `offset`/`limit` over the matching set.
    ///
    /// NOTE(review): offset/limit are applied in listing order *before* the
    /// final timestamp sort, so the page is sorted but is not guaranteed to
    /// be the globally newest N — confirm whether callers rely on that.
    async fn query(&self, query: SnapshotQuery) -> Result<Vec<Snapshot>, StorageError> {
        // List all snapshots
        let paths = self.list_objects("snapshots/").await?;

        let mut snapshots = Vec::new();
        let mut count = 0;
        let offset = query.offset.unwrap_or(0);
        let limit = query.limit.unwrap_or(usize::MAX);

        for path in paths {
            if let Some(filename) = path.split('/').next_back() {
                if let Some(snapshot_id) = filename.strip_suffix(".json") {
                    // Load snapshot to check filters
                    match self.load(snapshot_id).await {
                        Ok(snapshot) => {
                            // Apply filters
                            if self.matches_query(&snapshot, &query) {
                                if count >= offset {
                                    snapshots.push(snapshot);
                                    if snapshots.len() >= limit {
                                        break;
                                    }
                                }
                                count += 1;
                            }
                        }
                        Err(_) => continue, // Skip invalid snapshots
                    }
                }
            }
        }

        // Sort by timestamp (newest first)
        snapshots.sort_by(|a, b| b.metadata.timestamp.cmp(&a.metadata.timestamp));

        Ok(snapshots)
    }

    /// Deletion is intentionally unsupported; see the comment below.
    async fn delete(&self, _snapshot_id: &str) -> Result<bool, StorageError> {
        // LakeFS doesn't have a direct delete API for objects in branches
        // This would typically involve creating a new commit without the object
        // For simplicity, we'll return an error indicating this operation isn't supported
        Err(StorageError::PermissionDenied(
            "LakeFS doesn't support direct object deletion. Use branch operations instead."
                .to_string(),
        ))
    }

    /// Upload all queued snapshots (and each of their decisions), then
    /// create a LakeFS commit recording the batch.
    ///
    /// When nothing is pending this returns an empty `FlushResult` without
    /// contacting LakeFS at all: attempting a commit with no staged changes
    /// would fail server-side, making an idle flush erroneously error out.
    ///
    /// # Errors
    /// Propagates serialization, upload, and commit failures.
    async fn flush(&self) -> Result<FlushResult, StorageError> {
        // Drain the queue under the lock; `mem::take` swaps in an empty Vec
        // without cloning every snapshot, and the guard drops immediately.
        let pending_snapshots = {
            let mut pending = self.pending_writes.lock().unwrap();
            std::mem::take(&mut *pending)
        };

        // Nothing queued: skip the network round-trips and the empty commit.
        if pending_snapshots.is_empty() {
            return Ok(FlushResult {
                snapshots_written: 0,
                bytes_written: 0,
                checkpoint_id: None,
            });
        }

        // NOTE(review): if an upload below fails, the drained snapshots are
        // lost from the queue; re-queueing them on error may be worth adding.
        let mut bytes_written = 0;

        // Upload all pending snapshots
        for snapshot in &pending_snapshots {
            let snapshot_id = snapshot.metadata.snapshot_id.to_string();
            let path = self.snapshot_path(&snapshot_id);

            let json_data = serde_json::to_vec(snapshot)
                .map_err(|e| StorageError::SerializationError(e.to_string()))?;

            bytes_written += json_data.len();

            self.upload_object(&path, &json_data).await?;

            // Also upload individual decisions
            for decision in &snapshot.decisions {
                let decision_id = decision.metadata.snapshot_id.to_string();
                let decision_path = self.decision_path(&decision_id);

                let decision_data = serde_json::to_vec(decision)
                    .map_err(|e| StorageError::SerializationError(e.to_string()))?;

                bytes_written += decision_data.len();
                self.upload_object(&decision_path, &decision_data).await?;
            }
        }

        // Create commit
        let commit_message = format!("Briefcase AI flush: {} snapshots", pending_snapshots.len());
        let commit_id = self.create_commit(&commit_message).await?;

        Ok(FlushResult {
            snapshots_written: pending_snapshots.len(),
            bytes_written,
            checkpoint_id: Some(commit_id),
        })
    }

    /// Probe the repository endpoint; `Ok(true)` iff it answers with a
    /// success status. Transport failures surface as `ConnectionError`.
    async fn health_check(&self) -> Result<bool, StorageError> {
        let url = format!("{}/repositories/{}", self.endpoint, self.repository);

        let response =
            self.client.get(&url).send().await.map_err(|e| {
                StorageError::ConnectionError(format!("Health check failed: {}", e))
            })?;

        Ok(response.status().is_success())
    }
}
401
#[cfg(all(feature = "async", feature = "networking"))]
impl LakeFSBackend {
    /// Check if snapshot matches query filters.
    ///
    /// A snapshot passes when its timestamp lies inside the query's time
    /// window and, if any decision-level filter is set, at least one of
    /// its decisions satisfies every set filter simultaneously.
    fn matches_query(&self, snapshot: &Snapshot, query: &SnapshotQuery) -> bool {
        // Time-window filters apply to the snapshot itself.
        if query
            .start_time
            .map_or(false, |start| snapshot.metadata.timestamp < start)
        {
            return false;
        }
        if query
            .end_time
            .map_or(false, |end| snapshot.metadata.timestamp > end)
        {
            return false;
        }

        // With no decision-level filters, the time window alone decides.
        let has_decision_filters = query.function_name.is_some()
            || query.module_name.is_some()
            || query.model_name.is_some()
            || query.tags.is_some();
        if !has_decision_filters {
            return true;
        }

        // Some decision must satisfy every set filter at once; an unset
        // filter (`None`) is treated as "always matches".
        snapshot.decisions.iter().any(|decision| {
            let function_ok = query
                .function_name
                .as_ref()
                .map_or(true, |name| decision.function_name == *name);

            let module_ok = query
                .module_name
                .as_ref()
                .map_or(true, |name| decision.module_name.as_ref() == Some(name));

            // A model-name filter can only match decisions that carry
            // model parameters.
            let model_ok = query.model_name.as_ref().map_or(true, |name| {
                decision
                    .model_parameters
                    .as_ref()
                    .map_or(false, |params| params.model_name == *name)
            });

            // Every queried tag must be present with the exact value.
            let tags_ok = query.tags.as_ref().map_or(true, |wanted| {
                wanted
                    .iter()
                    .all(|(key, value)| decision.tags.get(key) == Some(value))
            });

            function_ok && module_ok && model_ok && tags_ok
        })
    }
}
475
/// Connection settings for the LakeFS storage backend.
#[derive(Debug, Clone)]
pub struct LakeFSConfig {
    /// Base URL of the LakeFS API server; a trailing slash is tolerated
    /// (it is stripped when the backend is constructed).
    pub endpoint: String,
    /// Target repository ("engagement" in Briefcase terminology).
    pub repository: String,
    /// Target branch ("workstream" in Briefcase terminology).
    pub branch: String,
    /// Access key for HTTP Basic authentication.
    pub access_key: String,
    /// Secret key for HTTP Basic authentication.
    pub secret_key: String,
}
484
485impl LakeFSConfig {
486    pub fn new(
487        endpoint: impl Into<String>,
488        repository: impl Into<String>,
489        branch: impl Into<String>,
490        access_key: impl Into<String>,
491        secret_key: impl Into<String>,
492    ) -> Self {
493        Self {
494            endpoint: endpoint.into(),
495            repository: repository.into(),
496            branch: branch.into(),
497            access_key: access_key.into(),
498            secret_key: secret_key.into(),
499        }
500    }
501}
502
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::*;
    use serde_json::json;

    /// Config pointing at a local LakeFS address. The tests below exercise
    /// only offline logic, so the server is never actually contacted.
    fn create_test_config() -> LakeFSConfig {
        LakeFSConfig::new(
            "http://localhost:8000",
            "briefcase-test",
            "main",
            "test_access_key",
            "test_secret_key",
        )
    }

    /// Build a session snapshot with one decision carrying an input, an
    /// output, "gpt-4" model parameters, and an `env=test` tag — the
    /// fixture used by the query-matching tests.
    async fn create_test_snapshot() -> Snapshot {
        let input = Input::new("test_input", json!("value"), "string");
        let output = Output::new("test_output", json!("result"), "string");
        let model_params = ModelParameters::new("gpt-4");

        let decision = DecisionSnapshot::new("test_function")
            .with_module("test_module")
            .add_input(input)
            .add_output(output)
            .with_model_parameters(model_params)
            .add_tag("env", "test");

        let mut snapshot = Snapshot::new(SnapshotType::Session);
        snapshot.add_decision(decision);
        snapshot
    }

    // Config construction preserves each field verbatim.
    #[tokio::test]
    async fn test_lakefs_config_creation() {
        let config = create_test_config();
        assert_eq!(config.endpoint, "http://localhost:8000");
        assert_eq!(config.repository, "briefcase-test");
        assert_eq!(config.branch, "main");
    }

    // Object keys follow the "snapshots/<id>.json" / "decisions/<id>.json"
    // layout the backend relies on when listing and loading.
    #[tokio::test]
    async fn test_object_paths() {
        let config = create_test_config();
        let backend = LakeFSBackend::new(config).unwrap();

        let snapshot_id = "test-snapshot-123";
        let decision_id = "test-decision-456";

        assert_eq!(
            backend.snapshot_path(snapshot_id),
            "snapshots/test-snapshot-123.json"
        );
        assert_eq!(
            backend.decision_path(decision_id),
            "decisions/test-decision-456.json"
        );
    }

    // matches_query: each filter dimension accepts the fixture's value and
    // rejects a mismatching one.
    #[tokio::test]
    async fn test_query_matching() {
        let config = create_test_config();
        let backend = LakeFSBackend::new(config).unwrap();
        let snapshot = create_test_snapshot().await;

        // Test function name matching
        let query = SnapshotQuery::new().with_function_name("test_function");
        assert!(backend.matches_query(&snapshot, &query));

        let query = SnapshotQuery::new().with_function_name("other_function");
        assert!(!backend.matches_query(&snapshot, &query));

        // Test tag matching
        let query = SnapshotQuery::new().with_tag("env", "test");
        assert!(backend.matches_query(&snapshot, &query));

        let query = SnapshotQuery::new().with_tag("env", "prod");
        assert!(!backend.matches_query(&snapshot, &query));

        // Test model name matching
        let query = SnapshotQuery::new().with_model_name("gpt-4");
        assert!(backend.matches_query(&snapshot, &query));

        let query = SnapshotQuery::new().with_model_name("claude-3");
        assert!(!backend.matches_query(&snapshot, &query));
    }

    // save() must buffer rather than upload: one call leaves exactly one
    // entry in the pending queue.
    #[tokio::test]
    async fn test_pending_writes() {
        let config = create_test_config();
        let backend = LakeFSBackend::new(config).unwrap();
        let snapshot = create_test_snapshot().await;

        // Save should add to pending writes
        let snapshot_id = backend.save(&snapshot).await.unwrap();
        assert_eq!(snapshot_id, snapshot.metadata.snapshot_id.to_string());

        // Check pending writes
        {
            let pending = backend.pending_writes.lock().unwrap();
            assert_eq!(pending.len(), 1);
        }
    }

    // Note: Integration tests against real LakeFS would require a running instance
    // These would be better suited for a separate integration test suite
}