// oxirs_chat/memory_optimization/streaming.rs

use anyhow::Result;
use std::marker::PhantomData;

/// Processes a data stream in fixed-size chunks to bound peak memory usage.
pub struct StreamProcessor<T> {
    /// Maximum number of items handed to a processing callback at once.
    chunk_size: usize,
    /// Ties the processor to item type `T` without storing any `T` values.
    _phantom: PhantomData<T>,
}
11
12impl<T> StreamProcessor<T> {
13 pub fn new(chunk_size: usize) -> Self {
14 Self {
15 chunk_size,
16 _phantom: PhantomData,
17 }
18 }
19
20 pub fn process_chunks<F, R>(&self, data: Vec<T>, mut processor: F) -> Result<Vec<R>>
22 where
23 F: FnMut(&[T]) -> Result<Vec<R>>,
24 {
25 let mut results = Vec::new();
26
27 for chunk in data.chunks(self.chunk_size) {
28 let chunk_results = processor(chunk)?;
29 results.extend(chunk_results);
30 }
31
32 Ok(results)
33 }
34
35 pub fn process_stream<I, F, R>(&self, iterator: I, mut processor: F) -> Result<Vec<R>>
37 where
38 I: Iterator<Item = T>,
39 F: FnMut(Vec<T>) -> Result<Vec<R>>,
40 {
41 let mut results = Vec::new();
42 let mut buffer = Vec::with_capacity(self.chunk_size);
43
44 for item in iterator {
45 buffer.push(item);
46
47 if buffer.len() >= self.chunk_size {
48 let chunk_results = processor(buffer)?;
49 results.extend(chunk_results);
50 buffer = Vec::with_capacity(self.chunk_size);
51 }
52 }
53
54 if !buffer.is_empty() {
56 let chunk_results = processor(buffer)?;
57 results.extend(chunk_results);
58 }
59
60 Ok(results)
61 }
62}
63
/// Derives a memory-bounded chunk size from a byte budget and an estimated
/// per-item size.
pub struct ChunkProcessor {
    /// Memory budget for one chunk, in megabytes.
    max_memory_mb: usize,
    /// Estimated size of a single item, in bytes.
    estimated_item_size: usize,
}
69
70impl ChunkProcessor {
71 pub fn new(max_memory_mb: usize, estimated_item_size: usize) -> Self {
72 Self {
73 max_memory_mb,
74 estimated_item_size,
75 }
76 }
77
78 pub fn optimal_chunk_size(&self) -> usize {
80 let max_bytes = self.max_memory_mb * 1024 * 1024;
81 let chunk_size = max_bytes / self.estimated_item_size;
82 chunk_size.max(1) }
84
85 pub fn process_embeddings<F>(
87 &self,
88 texts: Vec<String>,
89 mut embed_fn: F,
90 ) -> Result<Vec<Vec<f32>>>
91 where
92 F: FnMut(&[String]) -> Result<Vec<Vec<f32>>>,
93 {
94 let chunk_size = self.optimal_chunk_size();
95 let mut all_embeddings = Vec::with_capacity(texts.len());
96
97 for chunk in texts.chunks(chunk_size) {
98 let chunk_embeddings = embed_fn(chunk)?;
99 all_embeddings.extend(chunk_embeddings);
100 }
101
102 Ok(all_embeddings)
103 }
104}
105
/// Folds a stream of items into a running state value without ever
/// materializing the whole stream.
pub struct StreamingAggregator<T> {
    state: T,
}

impl<T> StreamingAggregator<T> {
    /// Starts aggregation from `initial_state`.
    pub fn new(initial_state: T) -> Self {
        Self { state: initial_state }
    }

    /// Feeds every element of `iterator` through `aggregator`, which folds it
    /// into the internal state; returns a borrow of the updated state.
    pub fn aggregate<I, F>(&mut self, iterator: I, mut aggregator: F) -> &T
    where
        I: Iterator,
        F: FnMut(&mut T, I::Item),
    {
        iterator.for_each(|element| aggregator(&mut self.state, element));
        &self.state
    }

    /// Borrows the current aggregation state.
    pub fn state(&self) -> &T {
        &self.state
    }
}
134
#[cfg(test)]
mod tests {
    use super::*;

    /// 25 items with chunk size 10 → chunks of 10, 10, 5; order preserved.
    #[test]
    fn test_stream_processor_chunks() {
        let processor = StreamProcessor::new(10);
        let data: Vec<i32> = (0..25).collect();

        let results = processor
            .process_chunks(data, |chunk| Ok(chunk.iter().map(|x| x * 2).collect()))
            .expect("should succeed");

        assert_eq!(results.len(), 25);
        assert_eq!(results[0], 0);
        assert_eq!(results[24], 48);
    }

    /// Iterator path: 10 items, chunk size 5 → two full buffers, no remainder.
    #[test]
    fn test_stream_processor_iterator() {
        let processor = StreamProcessor::new(5);
        let data = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10];

        let results = processor
            .process_stream(data.into_iter(), |chunk| {
                Ok(chunk.iter().map(|x| x * 2).collect())
            })
            .expect("should succeed");

        assert_eq!(results.len(), 10);
        assert_eq!(results[0], 2);
        assert_eq!(results[9], 20);
    }

    /// 100 MB budget / 1 KiB items = 102_400 items per chunk.
    #[test]
    fn test_chunk_processor_optimal_size() {
        let processor = ChunkProcessor::new(100, 1024);
        assert_eq!(processor.optimal_chunk_size(), 102400);
    }

    #[test]
    fn test_streaming_aggregator() {
        let mut aggregator = StreamingAggregator::new(0i32);

        let data = vec![1, 2, 3, 4, 5];
        let result = aggregator.aggregate(data.into_iter(), |state, item| {
            *state += item;
        });

        assert_eq!(*result, 15);
    }

    #[test]
    fn test_chunk_processor_embeddings() {
        let processor = ChunkProcessor::new(10, 1000);
        let texts: Vec<String> = (0..50).map(|i| format!("text_{}", i)).collect();

        let embeddings = processor
            .process_embeddings(texts, |chunk| {
                Ok(chunk.iter().map(|_| vec![1.0, 2.0, 3.0]).collect())
            })
            .expect("should succeed");

        assert_eq!(embeddings.len(), 50);
        // Every row should come straight from the stub embed function.
        assert!(embeddings.iter().all(|e| e == &vec![1.0, 2.0, 3.0]));
    }
}
204}