1use crate::arrow::{arrow_to_tensor_block, TensorBlockArrowExt};
7use crate::error::Result;
8use crate::hash::global_hash_registry;
9use crate::tensor::{TensorBlock, TensorShape};
10use arrow_array::{ArrayRef, RecordBatch};
11use arrow_schema::Schema;
12use multihash_codetable::Code;
13use std::sync::Arc;
14
15pub struct TensorBatchProcessor {
17 hash_algo: Code,
18}
19
20impl TensorBatchProcessor {
21 pub fn new(hash_algo: Code) -> Self {
23 Self { hash_algo }
24 }
25
26 pub fn process_batch(&self, tensors: &[TensorBlock]) -> Result<Vec<String>> {
28 let registry = global_hash_registry();
29 let mut cids = Vec::with_capacity(tensors.len());
30
31 for tensor in tensors {
32 let data = tensor.data();
33 let _hash = registry.digest(self.hash_algo, data)?;
34 cids.push(tensor.cid().to_string());
35 }
36
37 Ok(cids)
38 }
39
40 pub fn to_arrow_batch(&self, tensors: Vec<(&str, &TensorBlock)>) -> Result<RecordBatch> {
42 let mut fields = Vec::new();
43 let mut arrays: Vec<ArrayRef> = Vec::new();
44
45 for (name, tensor) in tensors {
46 fields.push(tensor.to_arrow_field(name));
47 arrays.push(tensor.to_arrow_array()?);
48 }
49
50 let schema = Arc::new(Schema::new(fields));
51 RecordBatch::try_new(schema, arrays).map_err(|e| {
52 crate::error::Error::InvalidInput(format!("Failed to create RecordBatch: {}", e))
53 })
54 }
55
56 pub fn from_arrow_batch(
58 &self,
59 batch: &RecordBatch,
60 shapes: Vec<TensorShape>,
61 ) -> Result<Vec<TensorBlock>> {
62 if batch.num_columns() != shapes.len() {
63 return Err(crate::error::Error::InvalidInput(format!(
64 "Column count {} doesn't match shape count {}",
65 batch.num_columns(),
66 shapes.len()
67 )));
68 }
69
70 let mut tensors = Vec::with_capacity(batch.num_columns());
71
72 for (col_idx, shape) in shapes.into_iter().enumerate() {
73 let array = batch.column(col_idx);
74 let tensor = arrow_to_tensor_block(array.as_ref(), shape)?;
75 tensors.push(tensor);
76 }
77
78 Ok(tensors)
79 }
80}
81
82impl Default for TensorBatchProcessor {
83 fn default() -> Self {
84 Self {
85 hash_algo: Code::Sha2_256,
86 }
87 }
88}
89
90pub struct TensorDeduplicator {
92 seen_cids: std::collections::HashMap<String, usize>,
93}
94
95impl TensorDeduplicator {
96 pub fn new() -> Self {
98 Self {
99 seen_cids: std::collections::HashMap::new(),
100 }
101 }
102
103 pub fn check(&mut self, tensor: &TensorBlock) -> Option<usize> {
106 let cid = tensor.cid().to_string();
107 self.seen_cids.get(&cid).copied()
108 }
109
110 pub fn register(&mut self, tensor: &TensorBlock) -> usize {
112 let cid = tensor.cid().to_string();
113 let idx = self.seen_cids.len();
114 self.seen_cids.entry(cid).or_insert(idx);
115 idx
116 }
117
118 pub fn unique_count(&self) -> usize {
120 self.seen_cids.len()
121 }
122
123 pub fn stats(&self) -> DeduplicationStats {
125 DeduplicationStats {
126 unique_tensors: self.seen_cids.len(),
127 total_checked: self.seen_cids.len(),
128 }
129 }
130}
131
132impl Default for TensorDeduplicator {
133 fn default() -> Self {
134 Self::new()
135 }
136}
137
/// Counters describing a deduplication run.
#[derive(Debug, Clone)]
pub struct DeduplicationStats {
    /// Number of unique tensors.
    pub unique_tensors: usize,
    /// Number of tensors checked.
    pub total_checked: usize,
}

impl DeduplicationStats {
    /// Returns `unique_tensors / total_checked` as a float, or `0.0` when
    /// `total_checked` is zero (avoiding a division by zero).
    pub fn dedup_ratio(&self) -> f64 {
        match self.total_checked {
            0 => 0.0,
            checked => self.unique_tensors as f64 / checked as f64,
        }
    }
}
156
157pub struct TensorStore {
159 tensors: std::collections::HashMap<String, TensorBlock>,
160}
161
162impl TensorStore {
163 pub fn new() -> Self {
165 Self {
166 tensors: std::collections::HashMap::new(),
167 }
168 }
169
170 pub fn store(&mut self, tensor: TensorBlock) -> String {
172 let cid = tensor.cid().to_string();
173 self.tensors.insert(cid.clone(), tensor);
174 cid
175 }
176
177 pub fn get(&self, cid: &str) -> Option<&TensorBlock> {
179 self.tensors.get(cid)
180 }
181
182 pub fn contains(&self, cid: &str) -> bool {
184 self.tensors.contains_key(cid)
185 }
186
187 pub fn len(&self) -> usize {
189 self.tensors.len()
190 }
191
192 pub fn is_empty(&self) -> bool {
194 self.tensors.is_empty()
195 }
196
197 pub fn list_cids(&self) -> Vec<String> {
199 self.tensors.keys().cloned().collect()
200 }
201
202 pub fn clear(&mut self) {
204 self.tensors.clear();
205 }
206}
207
208impl Default for TensorStore {
209 fn default() -> Self {
210 Self::new()
211 }
212}
213
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an f32 tensor block with the given shape, panicking on failure.
    fn f32_tensor(values: &[f32], shape: TensorShape) -> TensorBlock {
        TensorBlock::from_f32_slice(values, shape).unwrap()
    }

    /// Two tensors with different contents must yield two distinct CIDs.
    #[test]
    fn test_batch_processor() {
        let batch = [
            f32_tensor(&[1.0, 2.0, 3.0, 4.0], TensorShape::new(vec![2, 2])),
            f32_tensor(&[5.0, 6.0, 7.0, 8.0], TensorShape::new(vec![2, 2])),
        ];

        let cids = TensorBatchProcessor::default()
            .process_batch(&batch)
            .unwrap();

        assert_eq!(cids.len(), 2);
        assert_ne!(cids[0], cids[1]);
    }

    /// Tensors written to an Arrow batch come back with the same values.
    #[test]
    fn test_arrow_batch_roundtrip() {
        let processor = TensorBatchProcessor::default();

        let data1 = vec![1.0f32, 2.0, 3.0, 4.0];
        let data2 = vec![5.0f32, 6.0, 7.0, 8.0];
        let tensor1 = f32_tensor(&data1, TensorShape::new(vec![4]));
        let tensor2 = f32_tensor(&data2, TensorShape::new(vec![4]));

        let named = vec![("t1", &tensor1), ("t2", &tensor2)];
        let batch = processor.to_arrow_batch(named).unwrap();
        assert_eq!(batch.num_columns(), 2);
        assert_eq!(batch.num_rows(), 4);

        let recovered = processor
            .from_arrow_batch(
                &batch,
                vec![TensorShape::new(vec![4]), TensorShape::new(vec![4])],
            )
            .unwrap();

        assert_eq!(recovered.len(), 2);
        assert_eq!(recovered[0].to_f32_vec().unwrap(), data1);
        assert_eq!(recovered[1].to_f32_vec().unwrap(), data2);
    }

    /// A second tensor with identical content is reported under the first one's index.
    #[test]
    fn test_tensor_deduplicator() {
        let values = [1.0f32, 2.0, 3.0, 4.0];
        let original = f32_tensor(&values, TensorShape::new(vec![4]));
        let duplicate = f32_tensor(&values, TensorShape::new(vec![4]));

        let mut dedup = TensorDeduplicator::new();
        assert_eq!(dedup.check(&original), None);

        let registered_at = dedup.register(&original);
        assert_eq!(dedup.check(&duplicate), Some(registered_at));
        assert_eq!(dedup.unique_count(), 1);
    }

    /// Store, look up, list and clear a single tensor.
    #[test]
    fn test_tensor_store() {
        let mut store = TensorStore::new();
        assert!(store.is_empty());

        let data = vec![1.0f32, 2.0, 3.0, 4.0];
        let cid = store.store(f32_tensor(&data, TensorShape::new(vec![4])));
        assert_eq!(store.len(), 1);
        assert!(store.contains(&cid));

        let retrieved = store.get(&cid).unwrap();
        assert_eq!(retrieved.to_f32_vec().unwrap(), data);

        let cids = store.list_cids();
        assert_eq!(cids.len(), 1);
        assert_eq!(cids[0], cid);

        store.clear();
        assert!(store.is_empty());
    }

    /// Two distinct tensors are counted as two unique entries.
    #[test]
    fn test_deduplication_stats() {
        let first = f32_tensor(&[1.0, 2.0], TensorShape::new(vec![2]));
        let second = f32_tensor(&[3.0, 4.0], TensorShape::new(vec![2]));
        let first_again = f32_tensor(&[1.0, 2.0], TensorShape::new(vec![2]));

        let mut dedup = TensorDeduplicator::new();
        dedup.register(&first);
        dedup.register(&second);
        let _ = dedup.check(&first_again);

        let stats = dedup.stats();
        assert_eq!(stats.unique_tensors, 2);
    }
}