1use super::{FlushResult, SnapshotQuery, StorageBackend, StorageError};
2use crate::models::{DecisionSnapshot, Snapshot};
3#[cfg(feature = "networking")]
4use base64::{engine::general_purpose, Engine as _};
5#[cfg(feature = "networking")]
6use reqwest::{
7 header::{HeaderMap, HeaderValue, AUTHORIZATION, CONTENT_TYPE},
8 Client,
9};
10use serde::{Deserialize, Serialize};
11use std::sync::{Arc, Mutex};
12
/// Storage backend that persists snapshots to a lakeFS repository over HTTP.
///
/// Writes made through `save` are buffered locally and only uploaded (and
/// committed) when `flush` is called.
#[cfg(all(feature = "async", feature = "networking"))]
#[derive(Clone)]
pub struct LakeFSBackend {
    /// HTTP client pre-configured with Basic-auth and JSON default headers.
    client: Client,
    /// Base API endpoint with any trailing '/' stripped at construction.
    endpoint: String,
    /// Target lakeFS repository name.
    repository: String,
    /// Branch that uploads and commits are applied to.
    branch: String,
    // NOTE(review): credentials are baked into the client's default headers in
    // `new` and not read afterwards, hence the dead_code allowance.
    #[allow(dead_code)]
    access_key: String,
    #[allow(dead_code)]
    secret_key: String,
    /// Snapshots buffered by `save` until `flush` uploads and commits them.
    pending_writes: Arc<Mutex<Vec<Snapshot>>>,
}
26
27#[cfg(all(feature = "async", feature = "networking"))]
28impl LakeFSBackend {
29 pub fn new(config: LakeFSConfig) -> Result<Self, StorageError> {
30 let mut headers = HeaderMap::new();
31
32 let credentials = format!("{}:{}", config.access_key, config.secret_key);
34 let encoded = general_purpose::STANDARD.encode(credentials.as_bytes());
35 let auth_header = format!("Basic {}", encoded);
36
37 headers.insert(
38 AUTHORIZATION,
39 HeaderValue::from_str(&auth_header).map_err(|e| {
40 StorageError::ConnectionError(format!("Invalid auth header: {}", e))
41 })?,
42 );
43
44 headers.insert(CONTENT_TYPE, HeaderValue::from_static("application/json"));
45
46 let client = Client::builder()
47 .default_headers(headers)
48 .timeout(std::time::Duration::from_secs(30))
49 .build()
50 .map_err(|e| {
51 StorageError::ConnectionError(format!("Failed to create HTTP client: {}", e))
52 })?;
53
54 Ok(Self {
55 client,
56 endpoint: config.endpoint.trim_end_matches('/').to_string(),
57 repository: config.repository,
58 branch: config.branch,
59 access_key: config.access_key,
60 secret_key: config.secret_key,
61 pending_writes: Arc::new(Mutex::new(Vec::new())),
62 })
63 }
64
65 async fn create_commit(&self, message: &str) -> Result<String, StorageError> {
67 let url = format!(
68 "{}/repositories/{}/branches/{}/commits",
69 self.endpoint, self.repository, self.branch
70 );
71
72 #[derive(Serialize)]
73 struct CommitRequest {
74 message: String,
75 metadata: std::collections::HashMap<String, String>,
76 }
77
78 let mut metadata = std::collections::HashMap::new();
79 metadata.insert("source".to_string(), "briefcase-ai".to_string());
80
81 let request = CommitRequest {
82 message: message.to_string(),
83 metadata,
84 };
85
86 let response = self
87 .client
88 .post(&url)
89 .json(&request)
90 .send()
91 .await
92 .map_err(|e| {
93 StorageError::ConnectionError(format!("Failed to create commit: {}", e))
94 })?;
95
96 let status = response.status();
97 if !status.is_success() {
98 let error_text = response.text().await.unwrap_or_default();
99 return Err(StorageError::ConnectionError(format!(
100 "Commit failed with status {}: {}",
101 status, error_text
102 )));
103 }
104
105 #[derive(Deserialize)]
106 struct CommitResponse {
107 id: String,
108 }
109
110 let commit_response: CommitResponse = response.json().await.map_err(|e| {
111 StorageError::SerializationError(format!("Failed to parse commit response: {}", e))
112 })?;
113
114 Ok(commit_response.id)
115 }
116
117 async fn upload_object(&self, path: &str, data: &[u8]) -> Result<(), StorageError> {
119 let url = format!(
120 "{}/repositories/{}/branches/{}/objects",
121 self.endpoint, self.repository, self.branch
122 );
123
124 let response = self
125 .client
126 .put(&url)
127 .query(&[("path", path)])
128 .header("Content-Type", "application/octet-stream")
129 .body(data.to_vec())
130 .send()
131 .await
132 .map_err(|e| {
133 StorageError::ConnectionError(format!("Failed to upload object: {}", e))
134 })?;
135
136 let status = response.status();
137 if !status.is_success() {
138 let error_text = response.text().await.unwrap_or_default();
139 return Err(StorageError::ConnectionError(format!(
140 "Upload failed with status {}: {}",
141 status, error_text
142 )));
143 }
144
145 Ok(())
146 }
147
148 async fn download_object(&self, path: &str) -> Result<Vec<u8>, StorageError> {
150 let url = format!(
151 "{}/repositories/{}/refs/{}/objects",
152 self.endpoint, self.repository, self.branch
153 );
154
155 let response = self
156 .client
157 .get(&url)
158 .query(&[("path", path)])
159 .send()
160 .await
161 .map_err(|e| {
162 StorageError::ConnectionError(format!("Failed to download object: {}", e))
163 })?;
164
165 if response.status() == reqwest::StatusCode::NOT_FOUND {
166 return Err(StorageError::NotFound(format!(
167 "Object not found: {}",
168 path
169 )));
170 }
171
172 let status = response.status();
173 if !status.is_success() {
174 let error_text = response.text().await.unwrap_or_default();
175 return Err(StorageError::ConnectionError(format!(
176 "Download failed with status {}: {}",
177 status, error_text
178 )));
179 }
180
181 let data = response.bytes().await.map_err(|e| {
182 StorageError::ConnectionError(format!("Failed to read response: {}", e))
183 })?;
184
185 Ok(data.to_vec())
186 }
187
188 async fn list_objects(&self, prefix: &str) -> Result<Vec<String>, StorageError> {
190 let url = format!(
191 "{}/repositories/{}/refs/{}/objects/ls",
192 self.endpoint, self.repository, self.branch
193 );
194
195 let response = self
196 .client
197 .get(&url)
198 .query(&[("prefix", prefix)])
199 .send()
200 .await
201 .map_err(|e| StorageError::ConnectionError(format!("Failed to list objects: {}", e)))?;
202
203 let status = response.status();
204 if !status.is_success() {
205 let error_text = response.text().await.unwrap_or_default();
206 return Err(StorageError::ConnectionError(format!(
207 "List failed with status {}: {}",
208 status, error_text
209 )));
210 }
211
212 #[derive(Deserialize)]
213 struct ListResponse {
214 results: Vec<ObjectInfo>,
215 }
216
217 #[derive(Deserialize)]
218 struct ObjectInfo {
219 path: String,
220 #[serde(rename = "type")]
221 object_type: String,
222 }
223
224 let list_response: ListResponse = response.json().await.map_err(|e| {
225 StorageError::SerializationError(format!("Failed to parse list response: {}", e))
226 })?;
227
228 let paths = list_response
229 .results
230 .into_iter()
231 .filter(|obj| obj.object_type == "object")
232 .map(|obj| obj.path)
233 .collect();
234
235 Ok(paths)
236 }
237
238 fn snapshot_path(&self, snapshot_id: &str) -> String {
240 format!("snapshots/{}.json", snapshot_id)
241 }
242
243 fn decision_path(&self, decision_id: &str) -> String {
245 format!("decisions/{}.json", decision_id)
246 }
247}
248
#[cfg(all(feature = "async", feature = "networking"))]
#[async_trait::async_trait]
impl StorageBackend for LakeFSBackend {
    /// Buffers the snapshot in memory and returns its id.
    ///
    /// Nothing touches the network here; buffered snapshots are uploaded and
    /// committed by the next `flush`.
    async fn save(&self, snapshot: &Snapshot) -> Result<String, StorageError> {
        let snapshot_id = snapshot.metadata.snapshot_id.to_string();

        // Scoped so the lock is released immediately after the push.
        {
            let mut pending = self.pending_writes.lock().unwrap();
            pending.push(snapshot.clone());
        }

        Ok(snapshot_id)
    }

    /// Serializes the decision to JSON and uploads it immediately (no
    /// buffering, unlike `save`). Returns the decision's id.
    async fn save_decision(&self, decision: &DecisionSnapshot) -> Result<String, StorageError> {
        let decision_id = decision.metadata.snapshot_id.to_string();
        let path = self.decision_path(&decision_id);

        let json_data = serde_json::to_vec(decision)
            .map_err(|e| StorageError::SerializationError(e.to_string()))?;

        self.upload_object(&path, &json_data).await?;

        Ok(decision_id)
    }

    /// Downloads and deserializes a snapshot by id.
    ///
    /// # Errors
    /// `NotFound` if no object exists at the snapshot path;
    /// `SerializationError` if the stored JSON does not parse.
    async fn load(&self, snapshot_id: &str) -> Result<Snapshot, StorageError> {
        let path = self.snapshot_path(snapshot_id);
        let data = self.download_object(&path).await?;

        let snapshot: Snapshot = serde_json::from_slice(&data)
            .map_err(|e| StorageError::SerializationError(e.to_string()))?;

        Ok(snapshot)
    }

    /// Downloads and deserializes a single decision by id.
    async fn load_decision(&self, decision_id: &str) -> Result<DecisionSnapshot, StorageError> {
        let path = self.decision_path(decision_id);
        let data = self.download_object(&path).await?;

        let decision: DecisionSnapshot = serde_json::from_slice(&data)
            .map_err(|e| StorageError::SerializationError(e.to_string()))?;

        Ok(decision)
    }

    /// Lists all stored snapshots, loads each one, and returns the subset
    /// matching `query`, sorted newest-first.
    ///
    /// NOTE(review): offset/limit are applied in listing order *before* the
    /// final timestamp sort, so pagination windows are only stable if the
    /// listing order is stable.
    async fn query(&self, query: SnapshotQuery) -> Result<Vec<Snapshot>, StorageError> {
        let paths = self.list_objects("snapshots/").await?;

        let mut snapshots = Vec::new();
        // Number of query matches seen so far; used to honor the offset.
        let mut matched = 0;
        let offset = query.offset.unwrap_or(0);
        let limit = query.limit.unwrap_or(usize::MAX);

        for path in paths {
            // Object paths look like "snapshots/<id>.json"; recover the id.
            if let Some(filename) = path.split('/').next_back() {
                if let Some(snapshot_id) = filename.strip_suffix(".json") {
                    match self.load(snapshot_id).await {
                        Ok(snapshot) => {
                            if self.matches_query(&snapshot, &query) {
                                if matched >= offset {
                                    snapshots.push(snapshot);
                                    if snapshots.len() >= limit {
                                        break;
                                    }
                                }
                                matched += 1;
                            }
                        }
                        // Skip unreadable/corrupt objects rather than failing
                        // the whole query.
                        Err(_) => continue,
                    }
                }
            }
        }

        // Newest first.
        snapshots.sort_by(|a, b| b.metadata.timestamp.cmp(&a.metadata.timestamp));

        Ok(snapshots)
    }

    /// Always fails: objects written through this backend are not deleted
    /// directly; branch-level operations must be used instead.
    async fn delete(&self, _snapshot_id: &str) -> Result<bool, StorageError> {
        Err(StorageError::PermissionDenied(
            "LakeFS doesn't support direct object deletion. Use branch operations instead."
                .to_string(),
        ))
    }

    /// Uploads every buffered snapshot (and its decisions) and seals them in a
    /// single commit on the configured branch.
    ///
    /// Fixes over the previous version:
    /// - An empty buffer returns immediately instead of issuing a commit with
    ///   zero staged changes.
    /// - On upload failure the drained snapshots are re-queued, so buffered
    ///   data is no longer silently lost.
    ///
    /// # Errors
    /// `SerializationError` if a snapshot/decision fails to serialize;
    /// `ConnectionError` if an upload or the commit fails.
    async fn flush(&self) -> Result<FlushResult, StorageError> {
        // Drain the buffer up front (without cloning) so concurrent `save`
        // calls are not blocked during network I/O.
        let pending_snapshots = {
            let mut pending = self.pending_writes.lock().unwrap();
            std::mem::take(&mut *pending)
        };

        // Nothing buffered: skip the network round-trip entirely.
        if pending_snapshots.is_empty() {
            return Ok(FlushResult {
                snapshots_written: 0,
                bytes_written: 0,
                checkpoint_id: None,
            });
        }

        // Upload everything; any failure short-circuits out of this block.
        let upload_result = async {
            let mut bytes = 0usize;
            for snapshot in &pending_snapshots {
                let snapshot_id = snapshot.metadata.snapshot_id.to_string();
                let path = self.snapshot_path(&snapshot_id);

                let json_data = serde_json::to_vec(snapshot)
                    .map_err(|e| StorageError::SerializationError(e.to_string()))?;
                bytes += json_data.len();
                self.upload_object(&path, &json_data).await?;

                // Decisions are also uploaded individually so they remain
                // retrievable via `load_decision`.
                for decision in &snapshot.decisions {
                    let decision_id = decision.metadata.snapshot_id.to_string();
                    let decision_path = self.decision_path(&decision_id);

                    let decision_data = serde_json::to_vec(decision)
                        .map_err(|e| StorageError::SerializationError(e.to_string()))?;
                    bytes += decision_data.len();
                    self.upload_object(&decision_path, &decision_data).await?;
                }
            }
            Ok::<usize, StorageError>(bytes)
        }
        .await;

        let bytes_written = match upload_result {
            Ok(bytes) => bytes,
            Err(e) => {
                // Re-queue the drained snapshots so a later flush can retry;
                // previously they were dropped even though nothing was stored.
                self.pending_writes.lock().unwrap().extend(pending_snapshots);
                return Err(e);
            }
        };

        // NOTE(review): if the commit itself fails, the objects stay staged on
        // the branch; re-running flush after re-saving would re-upload them.
        let commit_message = format!("Briefcase AI flush: {} snapshots", pending_snapshots.len());
        let commit_id = self.create_commit(&commit_message).await?;

        Ok(FlushResult {
            snapshots_written: pending_snapshots.len(),
            bytes_written,
            checkpoint_id: Some(commit_id),
        })
    }

    /// Reports whether the configured repository responds successfully to a
    /// GET; transport failures map to `ConnectionError`.
    async fn health_check(&self) -> Result<bool, StorageError> {
        let url = format!("{}/repositories/{}", self.endpoint, self.repository);

        let response =
            self.client.get(&url).send().await.map_err(|e| {
                StorageError::ConnectionError(format!("Health check failed: {}", e))
            })?;

        Ok(response.status().is_success())
    }
}
401
#[cfg(all(feature = "async", feature = "networking"))]
impl LakeFSBackend {
    /// Returns true when `snapshot` satisfies every filter in `query`.
    ///
    /// Time bounds (inclusive) are checked against the snapshot's own
    /// timestamp; the decision-level filters (function, module, model, tags)
    /// must all hold for at least one decision within the snapshot.
    fn matches_query(&self, snapshot: &Snapshot, query: &SnapshotQuery) -> bool {
        // Reject snapshots outside the requested time window.
        if query
            .start_time
            .map_or(false, |start| snapshot.metadata.timestamp < start)
        {
            return false;
        }
        if query
            .end_time
            .map_or(false, |end| snapshot.metadata.timestamp > end)
        {
            return false;
        }

        // Decision-level filters apply only when at least one is present;
        // otherwise the time window alone decides.
        let has_decision_filters = query.function_name.is_some()
            || query.module_name.is_some()
            || query.model_name.is_some()
            || query.tags.is_some();
        if !has_decision_filters {
            return true;
        }

        // The snapshot matches if any single decision satisfies all filters.
        snapshot.decisions.iter().any(|decision| {
            let function_ok = query
                .function_name
                .as_ref()
                .map_or(true, |name| decision.function_name == *name);

            let module_ok = query
                .module_name
                .as_ref()
                .map_or(true, |name| decision.module_name.as_ref() == Some(name));

            // A model filter can only match decisions that carry parameters.
            let model_ok = query.model_name.as_ref().map_or(true, |name| {
                decision
                    .model_parameters
                    .as_ref()
                    .map_or(false, |params| params.model_name == *name)
            });

            // Every requested tag must be present with the exact value.
            let tags_ok = query.tags.as_ref().map_or(true, |wanted| {
                wanted
                    .iter()
                    .all(|(key, value)| decision.tags.get(key) == Some(value))
            });

            function_ok && module_ok && model_ok && tags_ok
        })
    }
}
475
/// Connection settings for the lakeFS storage backend.
#[derive(Debug, Clone)]
pub struct LakeFSConfig {
    /// Base URL of the lakeFS API; a trailing '/' is stripped by the backend.
    pub endpoint: String,
    /// Repository that snapshots and decisions are written to.
    pub repository: String,
    /// Branch used for object uploads and commits.
    pub branch: String,
    /// Credentials sent as HTTP Basic auth on every request.
    pub access_key: String,
    pub secret_key: String,
}
484
485impl LakeFSConfig {
486 pub fn new(
487 endpoint: impl Into<String>,
488 repository: impl Into<String>,
489 branch: impl Into<String>,
490 access_key: impl Into<String>,
491 secret_key: impl Into<String>,
492 ) -> Self {
493 Self {
494 endpoint: endpoint.into(),
495 repository: repository.into(),
496 branch: branch.into(),
497 access_key: access_key.into(),
498 secret_key: secret_key.into(),
499 }
500 }
501}
502
#[cfg(test)]
mod tests {
    use super::*;
    use crate::models::*;
    use serde_json::json;

    /// Config pointing at a local lakeFS instance; these tests never actually
    /// contact it.
    fn create_test_config() -> LakeFSConfig {
        LakeFSConfig::new(
            "http://localhost:8000",
            "briefcase-test",
            "main",
            "test_access_key",
            "test_secret_key",
        )
    }

    /// Builds a session snapshot containing one fully-populated decision.
    // Fix: was declared `async` although it awaited nothing, forcing callers
    // onto a Tokio runtime just to build fixtures. Now synchronous.
    fn create_test_snapshot() -> Snapshot {
        let input = Input::new("test_input", json!("value"), "string");
        let output = Output::new("test_output", json!("result"), "string");
        let model_params = ModelParameters::new("gpt-4");

        let decision = DecisionSnapshot::new("test_function")
            .with_module("test_module")
            .add_input(input)
            .add_output(output)
            .with_model_parameters(model_params)
            .add_tag("env", "test");

        let mut snapshot = Snapshot::new(SnapshotType::Session);
        snapshot.add_decision(decision);
        snapshot
    }

    // Fix: plain #[test] — no async work, no runtime needed.
    #[test]
    fn test_lakefs_config_creation() {
        let config = create_test_config();
        assert_eq!(config.endpoint, "http://localhost:8000");
        assert_eq!(config.repository, "briefcase-test");
        assert_eq!(config.branch, "main");
    }

    // Fix: plain #[test] — `LakeFSBackend::new` and the path helpers are sync.
    #[test]
    fn test_object_paths() {
        let config = create_test_config();
        let backend = LakeFSBackend::new(config).unwrap();

        let snapshot_id = "test-snapshot-123";
        let decision_id = "test-decision-456";

        assert_eq!(
            backend.snapshot_path(snapshot_id),
            "snapshots/test-snapshot-123.json"
        );
        assert_eq!(
            backend.decision_path(decision_id),
            "decisions/test-decision-456.json"
        );
    }

    // Fix: plain #[test] — `matches_query` is a pure synchronous predicate.
    #[test]
    fn test_query_matching() {
        let config = create_test_config();
        let backend = LakeFSBackend::new(config).unwrap();
        let snapshot = create_test_snapshot();

        // Function-name filter.
        let query = SnapshotQuery::new().with_function_name("test_function");
        assert!(backend.matches_query(&snapshot, &query));

        let query = SnapshotQuery::new().with_function_name("other_function");
        assert!(!backend.matches_query(&snapshot, &query));

        // Tag filter.
        let query = SnapshotQuery::new().with_tag("env", "test");
        assert!(backend.matches_query(&snapshot, &query));

        let query = SnapshotQuery::new().with_tag("env", "prod");
        assert!(!backend.matches_query(&snapshot, &query));

        // Model-name filter.
        let query = SnapshotQuery::new().with_model_name("gpt-4");
        assert!(backend.matches_query(&snapshot, &query));

        let query = SnapshotQuery::new().with_model_name("claude-3");
        assert!(!backend.matches_query(&snapshot, &query));
    }

    // `save` is genuinely async, so this one keeps the Tokio runtime.
    #[tokio::test]
    async fn test_pending_writes() {
        let config = create_test_config();
        let backend = LakeFSBackend::new(config).unwrap();
        let snapshot = create_test_snapshot();

        let snapshot_id = backend.save(&snapshot).await.unwrap();
        assert_eq!(snapshot_id, snapshot.metadata.snapshot_id.to_string());

        // `save` must only buffer — exactly one entry, nothing uploaded.
        {
            let pending = backend.pending_writes.lock().unwrap();
            assert_eq!(pending.len(), 1);
        }
    }
}
607 }