kaish_kernel/scheduler/
stream.rs1use std::collections::VecDeque;
10use std::sync::Arc;
11use tokio::sync::RwLock;
12
/// Default capacity, in bytes, used by [`BoundedStream::default_size`] (10 MiB).
pub const DEFAULT_STREAM_MAX_SIZE: usize = 10 * 1024 * 1024;
15
/// Cloneable handle to a size-bounded, in-memory byte buffer.
///
/// The state is held behind `Arc<RwLock<..>>`, so every clone of a handle
/// refers to the same underlying buffer and counters.
#[derive(Clone)]
pub struct BoundedStream {
    /// Shared, lock-protected stream state.
    inner: Arc<RwLock<BoundedStreamInner>>,
}
39
/// Lock-protected state behind a [`BoundedStream`] handle.
struct BoundedStreamInner {
    /// Bytes currently retained; the oldest byte is at the front.
    buffer: VecDeque<u8>,
    /// Upper bound on `buffer.len()`, in bytes.
    max_size: usize,
    /// Running total of bytes passed to `write` while the stream was open,
    /// including bytes that were later evicted or never stored.
    total_written: u64,
    /// Running total of bytes dropped in order to respect `max_size`.
    bytes_evicted: u64,
    /// Once set, `write` ignores any further data.
    closed: bool,
}
52
53impl BoundedStream {
54 pub fn new(max_size: usize) -> Self {
56 Self {
57 inner: Arc::new(RwLock::new(BoundedStreamInner {
58 buffer: VecDeque::with_capacity(max_size.min(8192)), max_size,
60 total_written: 0,
61 bytes_evicted: 0,
62 closed: false,
63 })),
64 }
65 }
66
67 pub fn default_size() -> Self {
69 Self::new(DEFAULT_STREAM_MAX_SIZE)
70 }
71
72 pub async fn write(&self, data: &[u8]) {
77 let mut inner = self.inner.write().await;
78
79 if inner.closed {
80 return;
81 }
82
83 inner.total_written += data.len() as u64;
84
85 if data.len() >= inner.max_size {
87 let start = data.len() - inner.max_size;
88 inner.bytes_evicted += inner.buffer.len() as u64 + start as u64;
89 inner.buffer.clear();
90 inner.buffer.extend(&data[start..]);
91 return;
92 }
93
94 let needed = data.len();
96 let available = inner.max_size.saturating_sub(inner.buffer.len());
97
98 if needed > available {
99 let to_evict = needed - available;
100 let actual_evict = to_evict.min(inner.buffer.len());
101 inner.buffer.drain(..actual_evict);
102 inner.bytes_evicted += actual_evict as u64;
103 }
104
105 inner.buffer.extend(data);
107 }
108
109 pub async fn read(&self) -> Vec<u8> {
114 let inner = self.inner.read().await;
115 inner.buffer.iter().copied().collect()
116 }
117
118 pub async fn read_string(&self) -> String {
120 let data = self.read().await;
121 String::from_utf8_lossy(&data).into_owned()
122 }
123
124 pub async fn close(&self) {
128 let mut inner = self.inner.write().await;
129 inner.closed = true;
130 }
131
132 pub async fn is_closed(&self) -> bool {
134 let inner = self.inner.read().await;
135 inner.closed
136 }
137
138 pub async fn len(&self) -> usize {
140 let inner = self.inner.read().await;
141 inner.buffer.len()
142 }
143
144 pub async fn is_empty(&self) -> bool {
146 self.len().await == 0
147 }
148
149 pub async fn stats(&self) -> StreamStats {
151 let inner = self.inner.read().await;
152 StreamStats {
153 current_size: inner.buffer.len(),
154 max_size: inner.max_size,
155 total_written: inner.total_written,
156 bytes_evicted: inner.bytes_evicted,
157 closed: inner.closed,
158 }
159 }
160}
161
162impl std::fmt::Debug for BoundedStream {
163 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
164 f.debug_struct("BoundedStream")
165 .field("inner", &"<locked>")
166 .finish()
167 }
168}
169
/// Point-in-time snapshot of a [`BoundedStream`]'s counters.
///
/// Derives `PartialEq`/`Eq` so two snapshots can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StreamStats {
    /// Number of bytes retained in the buffer when the snapshot was taken.
    pub current_size: usize,
    /// Maximum number of bytes the stream retains.
    pub max_size: usize,
    /// Total bytes accepted by `write` so far, including evicted ones.
    pub total_written: u64,
    /// Total bytes dropped so far to stay within `max_size`.
    pub bytes_evicted: u64,
    /// Whether the stream had been closed when the snapshot was taken.
    pub closed: bool,
}
184
185pub async fn drain_to_stream<R>(mut reader: R, stream: Arc<BoundedStream>)
190where
191 R: tokio::io::AsyncRead + Unpin,
192{
193 use tokio::io::AsyncReadExt;
194
195 let mut buf = [0u8; 8192];
196 loop {
197 match reader.read(&mut buf).await {
198 Ok(0) => break, Ok(n) => stream.write(&buf[..n]).await,
200 Err(e) => {
201 tracing::warn!("drain_to_stream read error: {}", e);
202 break;
203 }
204 }
205 }
206 stream.close().await;
207}
208
#[cfg(test)]
mod tests {
    use super::*;

    /// A single write is returned unchanged by `read`.
    #[tokio::test]
    async fn test_basic_write_read() {
        let stream = BoundedStream::new(100);
        stream.write(b"hello").await;
        assert_eq!(stream.read().await, b"hello");
    }

    /// Consecutive writes are concatenated in order.
    #[tokio::test]
    async fn test_multiple_writes() {
        let stream = BoundedStream::new(100);
        stream.write(b"hello ").await;
        stream.write(b"world").await;
        assert_eq!(stream.read().await, b"hello world");
    }

    /// Writing into a full buffer evicts exactly as many old bytes as arrive.
    #[tokio::test]
    async fn test_eviction_on_overflow() {
        let stream = BoundedStream::new(10);
        stream.write(b"12345").await;
        stream.write(b"67890").await;
        assert_eq!(stream.len().await, 10);

        // The buffer is full, so the five oldest bytes must make room.
        stream.write(b"ABCDE").await;
        assert_eq!(stream.read().await, b"67890ABCDE");

        let stats = stream.stats().await;
        assert_eq!(stats.current_size, 10);
        assert_eq!(stats.bytes_evicted, 5);
        assert_eq!(stats.total_written, 15);
    }

    /// Only the bytes that do not fit are evicted when the buffer is partly full.
    #[tokio::test]
    async fn test_partial_eviction() {
        let stream = BoundedStream::new(10);
        stream.write(b"12345678").await;

        // Two bytes are free, five arrive: three old bytes must go.
        stream.write(b"ABCDE").await;
        assert_eq!(stream.read().await, b"45678ABCDE");

        let stats = stream.stats().await;
        assert_eq!(stats.bytes_evicted, 3);
        assert_eq!(stats.total_written, 13);
    }

    /// A write larger than the bound keeps only its tail and counts the rest
    /// as evicted.
    #[tokio::test]
    async fn test_large_write_exceeds_buffer() {
        let stream = BoundedStream::new(10);
        stream.write(b"0123456789ABCDEFGHIJ").await;
        assert_eq!(stream.read().await, b"ABCDEFGHIJ");

        let stats = stream.stats().await;
        assert_eq!(stats.current_size, 10);
        assert_eq!(stats.total_written, 20);
        assert_eq!(stats.bytes_evicted, 10);
    }

    /// A zero-capacity stream retains nothing but still counts the traffic.
    #[tokio::test]
    async fn test_zero_capacity() {
        let stream = BoundedStream::new(0);
        stream.write(b"abc").await;
        assert!(stream.is_empty().await);

        let stats = stream.stats().await;
        assert_eq!(stats.total_written, 3);
        assert_eq!(stats.bytes_evicted, 3);
    }

    /// Writes after `close` are dropped and do not affect the counters.
    #[tokio::test]
    async fn test_close_prevents_writes() {
        let stream = BoundedStream::new(100);
        stream.write(b"before").await;
        stream.close().await;
        stream.write(b"after").await;
        assert_eq!(stream.read().await, b"before");

        assert!(stream.is_closed().await);
        let stats = stream.stats().await;
        assert_eq!(stats.total_written, 6);
        assert_eq!(stats.bytes_evicted, 0);
        assert!(stats.closed);
    }

    /// Valid UTF-8 content round-trips through `read_string`.
    #[tokio::test]
    async fn test_read_string() {
        let stream = BoundedStream::new(100);
        stream.write(b"hello world").await;
        assert_eq!(stream.read_string().await, "hello world");
    }

    /// Invalid UTF-8 is replaced rather than causing a failure.
    #[tokio::test]
    async fn test_read_string_invalid_utf8() {
        let stream = BoundedStream::new(100);
        stream.write(&[b'f', 0xff, b'o']).await;
        assert_eq!(stream.read_string().await, "f\u{FFFD}o");
    }

    /// Concurrent writers never lose or duplicate data while under the bound.
    #[tokio::test]
    async fn test_concurrent_writes() {
        let stream = Arc::new(BoundedStream::new(1000));

        let handles: Vec<_> = (0..10)
            .map(|i| {
                let s = stream.clone();
                tokio::spawn(async move {
                    for j in 0..10 {
                        s.write(format!("[{}-{}]", i, j).as_bytes()).await;
                    }
                })
            })
            .collect();

        for h in handles {
            h.await.expect("task should not panic");
        }

        // 100 writes of 5 bytes each ("[i-j]" with single-digit i and j) fit
        // within the 1000-byte bound, so nothing may be evicted. Only the
        // interleaving order is unspecified.
        let data = stream.read().await;
        assert_eq!(data.len(), 500);

        let stats = stream.stats().await;
        assert_eq!(stats.total_written, 500);
        assert_eq!(stats.bytes_evicted, 0);
    }

    /// `stats` reflects a buffer filled exactly to its bound.
    #[tokio::test]
    async fn test_stats() {
        let stream = BoundedStream::new(10);
        stream.write(b"1234567890").await;

        let stats = stream.stats().await;
        assert_eq!(stats.current_size, 10);
        assert_eq!(stats.max_size, 10);
        assert_eq!(stats.total_written, 10);
        assert_eq!(stats.bytes_evicted, 0);
        assert!(!stats.closed);
    }

    /// A fresh stream is empty and open.
    #[tokio::test]
    async fn test_empty_stream() {
        let stream = BoundedStream::new(100);
        assert!(stream.is_empty().await);
        assert_eq!(stream.len().await, 0);
        assert_eq!(stream.read().await, Vec::<u8>::new());
        assert!(!stream.is_closed().await);
    }

    /// `drain_to_stream` copies the reader's content and closes the stream.
    #[tokio::test]
    async fn test_drain_to_stream() {
        use std::io::Cursor;

        let data = b"test data from reader";
        let cursor = Cursor::new(data.to_vec());
        let stream = Arc::new(BoundedStream::new(100));

        drain_to_stream(cursor, stream.clone()).await;

        assert_eq!(stream.read().await, data);
        assert!(stream.is_closed().await);
    }

    /// `default_size` uses the module-level default bound.
    #[tokio::test]
    async fn test_default_size() {
        let stream = BoundedStream::default_size();
        let stats = stream.stats().await;
        assert_eq!(stats.max_size, DEFAULT_STREAM_MAX_SIZE);
    }
}