ipfrs_core/
integration.rs

1//! Integration utilities combining multiple ipfrs-core features.
2//!
3//! This module provides high-level utilities that combine tensor operations,
4//! Arrow integration, and content-addressed storage for common workflows.
5
6use crate::arrow::{arrow_to_tensor_block, TensorBlockArrowExt};
7use crate::error::Result;
8use crate::hash::global_hash_registry;
9use crate::tensor::{TensorBlock, TensorShape};
10use arrow_array::{ArrayRef, RecordBatch};
11use arrow_schema::Schema;
12use multihash_codetable::Code;
13use std::sync::Arc;
14
15/// Batch processor for tensor blocks
16pub struct TensorBatchProcessor {
17    hash_algo: Code,
18}
19
20impl TensorBatchProcessor {
21    /// Create a new batch processor with the specified hash algorithm
22    pub fn new(hash_algo: Code) -> Self {
23        Self { hash_algo }
24    }
25
26    /// Process multiple tensors and generate CIDs with hardware acceleration
27    pub fn process_batch(&self, tensors: &[TensorBlock]) -> Result<Vec<String>> {
28        let registry = global_hash_registry();
29        let mut cids = Vec::with_capacity(tensors.len());
30
31        for tensor in tensors {
32            let data = tensor.data();
33            let _hash = registry.digest(self.hash_algo, data)?;
34            cids.push(tensor.cid().to_string());
35        }
36
37        Ok(cids)
38    }
39
40    /// Convert multiple tensors to an Arrow RecordBatch
41    pub fn to_arrow_batch(&self, tensors: Vec<(&str, &TensorBlock)>) -> Result<RecordBatch> {
42        let mut fields = Vec::new();
43        let mut arrays: Vec<ArrayRef> = Vec::new();
44
45        for (name, tensor) in tensors {
46            fields.push(tensor.to_arrow_field(name));
47            arrays.push(tensor.to_arrow_array()?);
48        }
49
50        let schema = Arc::new(Schema::new(fields));
51        RecordBatch::try_new(schema, arrays).map_err(|e| {
52            crate::error::Error::InvalidInput(format!("Failed to create RecordBatch: {}", e))
53        })
54    }
55
56    /// Process Arrow RecordBatch and convert to tensor blocks
57    pub fn from_arrow_batch(
58        &self,
59        batch: &RecordBatch,
60        shapes: Vec<TensorShape>,
61    ) -> Result<Vec<TensorBlock>> {
62        if batch.num_columns() != shapes.len() {
63            return Err(crate::error::Error::InvalidInput(format!(
64                "Column count {} doesn't match shape count {}",
65                batch.num_columns(),
66                shapes.len()
67            )));
68        }
69
70        let mut tensors = Vec::with_capacity(batch.num_columns());
71
72        for (col_idx, shape) in shapes.into_iter().enumerate() {
73            let array = batch.column(col_idx);
74            let tensor = arrow_to_tensor_block(array.as_ref(), shape)?;
75            tensors.push(tensor);
76        }
77
78        Ok(tensors)
79    }
80}
81
82impl Default for TensorBatchProcessor {
83    fn default() -> Self {
84        Self {
85            hash_algo: Code::Sha2_256,
86        }
87    }
88}
89
90/// Utility for tensor deduplication using content-addressed storage
91pub struct TensorDeduplicator {
92    seen_cids: std::collections::HashMap<String, usize>,
93}
94
95impl TensorDeduplicator {
96    /// Create a new deduplicator
97    pub fn new() -> Self {
98        Self {
99            seen_cids: std::collections::HashMap::new(),
100        }
101    }
102
103    /// Check if a tensor has been seen before (by CID)
104    /// Returns the index of the first occurrence if found
105    pub fn check(&mut self, tensor: &TensorBlock) -> Option<usize> {
106        let cid = tensor.cid().to_string();
107        self.seen_cids.get(&cid).copied()
108    }
109
110    /// Register a tensor and return its index
111    pub fn register(&mut self, tensor: &TensorBlock) -> usize {
112        let cid = tensor.cid().to_string();
113        let idx = self.seen_cids.len();
114        self.seen_cids.entry(cid).or_insert(idx);
115        idx
116    }
117
118    /// Get the number of unique tensors
119    pub fn unique_count(&self) -> usize {
120        self.seen_cids.len()
121    }
122
123    /// Get deduplication statistics
124    pub fn stats(&self) -> DeduplicationStats {
125        DeduplicationStats {
126            unique_tensors: self.seen_cids.len(),
127            total_checked: self.seen_cids.len(),
128        }
129    }
130}
131
132impl Default for TensorDeduplicator {
133    fn default() -> Self {
134        Self::new()
135    }
136}
137
/// Statistics for tensor deduplication.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct DeduplicationStats {
    /// Number of unique tensors seen (distinct CIDs).
    pub unique_tensors: usize,
    /// Total number of tensors checked for deduplication.
    pub total_checked: usize,
}

impl DeduplicationStats {
    /// Calculate the deduplication ratio `unique_tensors / total_checked`.
    ///
    /// The ratio equals `1.0` when every checked tensor was unique; smaller
    /// values indicate more duplication. Returns `0.0` when nothing has been
    /// checked, which avoids a division by zero.
    pub fn dedup_ratio(&self) -> f64 {
        if self.total_checked == 0 {
            0.0
        } else {
            self.unique_tensors as f64 / self.total_checked as f64
        }
    }
}
156
157/// High-level API for tensor storage and retrieval
158pub struct TensorStore {
159    tensors: std::collections::HashMap<String, TensorBlock>,
160}
161
162impl TensorStore {
163    /// Create a new tensor store
164    pub fn new() -> Self {
165        Self {
166            tensors: std::collections::HashMap::new(),
167        }
168    }
169
170    /// Store a tensor and return its CID
171    pub fn store(&mut self, tensor: TensorBlock) -> String {
172        let cid = tensor.cid().to_string();
173        self.tensors.insert(cid.clone(), tensor);
174        cid
175    }
176
177    /// Retrieve a tensor by CID
178    pub fn get(&self, cid: &str) -> Option<&TensorBlock> {
179        self.tensors.get(cid)
180    }
181
182    /// Check if a tensor exists
183    pub fn contains(&self, cid: &str) -> bool {
184        self.tensors.contains_key(cid)
185    }
186
187    /// Get the number of stored tensors
188    pub fn len(&self) -> usize {
189        self.tensors.len()
190    }
191
192    /// Check if the store is empty
193    pub fn is_empty(&self) -> bool {
194        self.tensors.is_empty()
195    }
196
197    /// List all CIDs in the store
198    pub fn list_cids(&self) -> Vec<String> {
199        self.tensors.keys().cloned().collect()
200    }
201
202    /// Clear all tensors from the store
203    pub fn clear(&mut self) {
204        self.tensors.clear();
205    }
206}
207
208impl Default for TensorStore {
209    fn default() -> Self {
210        Self::new()
211    }
212}
213
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a tensor fixture from `values` laid out as `shape`.
    ///
    /// Test inputs are always valid, so a construction failure is a test bug.
    fn make_tensor(values: &[f32], shape: TensorShape) -> TensorBlock {
        TensorBlock::from_f32_slice(values, shape).unwrap()
    }

    #[test]
    fn test_batch_processor() {
        let processor = TensorBatchProcessor::default();

        let first = make_tensor(&[1.0, 2.0, 3.0, 4.0], TensorShape::new(vec![2, 2]));
        let second = make_tensor(&[5.0, 6.0, 7.0, 8.0], TensorShape::new(vec![2, 2]));

        let cids = processor.process_batch(&[first, second]).unwrap();
        assert_eq!(cids.len(), 2);
        // Distinct contents must map to distinct CIDs.
        assert_ne!(cids[0], cids[1]);
    }

    #[test]
    fn test_arrow_batch_roundtrip() {
        let processor = TensorBatchProcessor::default();

        let left_values = vec![1.0f32, 2.0, 3.0, 4.0];
        let right_values = vec![5.0f32, 6.0, 7.0, 8.0];
        let left = make_tensor(&left_values, TensorShape::new(vec![4]));
        let right = make_tensor(&right_values, TensorShape::new(vec![4]));

        // Tensors -> Arrow: one column per tensor, one row per element.
        let batch = processor
            .to_arrow_batch(vec![("t1", &left), ("t2", &right)])
            .unwrap();
        assert_eq!(batch.num_columns(), 2);
        assert_eq!(batch.num_rows(), 4);

        // Arrow -> tensors must reproduce the original values column by column.
        let restored = processor
            .from_arrow_batch(
                &batch,
                vec![TensorShape::new(vec![4]), TensorShape::new(vec![4])],
            )
            .unwrap();
        assert_eq!(restored.len(), 2);
        assert_eq!(restored[0].to_f32_vec().unwrap(), left_values);
        assert_eq!(restored[1].to_f32_vec().unwrap(), right_values);
    }

    #[test]
    fn test_tensor_deduplicator() {
        let mut dedup = TensorDeduplicator::new();

        let values = vec![1.0f32, 2.0, 3.0, 4.0];
        let original = make_tensor(&values, TensorShape::new(vec![4]));
        // Identical contents, so the CID is identical as well.
        let twin = make_tensor(&values, TensorShape::new(vec![4]));

        assert_eq!(dedup.check(&original), None);
        let first_idx = dedup.register(&original);

        assert_eq!(dedup.check(&twin), Some(first_idx));
        assert_eq!(dedup.unique_count(), 1);
    }

    #[test]
    fn test_tensor_store() {
        let mut store = TensorStore::new();
        assert!(store.is_empty());

        let values = vec![1.0f32, 2.0, 3.0, 4.0];
        let tensor = make_tensor(&values, TensorShape::new(vec![4]));

        let cid = store.store(tensor.clone());
        assert_eq!(store.len(), 1);
        assert!(store.contains(&cid));

        let fetched = store.get(&cid).unwrap();
        assert_eq!(fetched.to_f32_vec().unwrap(), values);

        let listed = store.list_cids();
        assert_eq!(listed.len(), 1);
        assert_eq!(listed[0], cid);

        store.clear();
        assert!(store.is_empty());
    }

    #[test]
    fn test_deduplication_stats() {
        let mut dedup = TensorDeduplicator::new();

        let ones = vec![1.0f32, 2.0];
        let threes = vec![3.0f32, 4.0];

        let a = make_tensor(&ones, TensorShape::new(vec![2]));
        let b = make_tensor(&threes, TensorShape::new(vec![2]));
        // Same contents as `a`.
        let a_again = make_tensor(&ones, TensorShape::new(vec![2]));

        dedup.register(&a);
        dedup.register(&b);
        let _ = dedup.check(&a_again);

        assert_eq!(dedup.stats().unique_tensors, 2);
    }
}
325}