Skip to main content

co_primitives/library/
node_stream.rs

1// SPDX-License-Identifier: AGPL-3.0-only
2// Copyright (C) 2026 1io BRANDGUARDIAN GmbH
3
4use super::node_builder::NodeReader;
5use crate::{BlockStorage, BlockStorageExt, Node, OptionLink, StorageError};
6use cid::Cid;
7use either::Either;
8use futures::{Future, FutureExt, Stream};
9use pin_project::pin_project;
10use serde::de::DeserializeOwned;
11use std::{
12	collections::VecDeque,
13	pin::Pin,
14	task::{Context, Poll},
15};
16
/// Stream node items.
///
/// Asynchronous [`Stream`] over the entries of a node tree kept in a [`BlockStorage`].
/// Nodes are loaded lazily, one at a time, via `get_deserialized`; each loaded node is read with
/// `filter` and produces either child links (queued on `stack`) or entries (buffered in `entries`).
#[pin_project]
pub struct NodeStream<S, T, N = Node<T>>
where
	N: NodeReader<T>,
{
	/// Storage the nodes are loaded from (cloned into every load future).
	storage: S,
	/// Links of nodes that still have to be loaded.
	stack: VecDeque<Cid>,
	/// Entries of the most recently read leaf that have not been yielded yet.
	entries: VecDeque<T>,
	/// The node load currently in flight, if any.
	#[pin]
	get: Option<Pin<Box<dyn Future<Output = Result<N, StorageError>> + Send>>>,
	/// Filter handed to `NodeReader::read` for every node that is read.
	filter: N::Filter,
	/// When `true`, links and entries are taken from the back instead of the front.
	reverse: bool,
}
31impl<S, T, N> NodeStream<S, T, N>
32where
33	S: BlockStorage + Clone + 'static,
34	T: DeserializeOwned + Send + Sync + 'static,
35	N: NodeReader<T> + DeserializeOwned + Send + Sync + 'static,
36{
37	pub fn new(storage: S, cid: Option<Cid>) -> Self {
38		let mut stack = VecDeque::new();
39		if let Some(cid) = cid {
40			stack.push_front(cid);
41		}
42		Self { storage, stack, entries: Default::default(), get: None, filter: Default::default(), reverse: false }
43	}
44
45	pub fn from_link(storage: S, link: OptionLink<N>) -> Self {
46		Self::new(storage, *link.cid())
47	}
48
49	pub fn from_node(storage: S, node: N, filter: Option<N::Filter>) -> Self {
50		let filter = filter.unwrap_or_default();
51		let (stack, entries) = match node.read(&filter) {
52			Either::Left(stack) => (stack.into_iter().collect(), Default::default()),
53			Either::Right(entries) => (Default::default(), entries.into_iter().collect()),
54		};
55		Self { storage, stack, entries, get: None, filter, reverse: false }
56	}
57
58	/// Iterate with filter.
59	pub fn with_filter(mut self, filter: N::Filter) -> Self {
60		self.filter = filter;
61		self
62	}
63
64	/// Iterate in reverse order.
65	pub fn with_reverse(mut self) -> Self {
66		self.reverse = true;
67		self
68	}
69}
70impl<S, T, N> Stream for NodeStream<S, T, N>
71where
72	S: BlockStorage + Clone + 'static,
73	T: DeserializeOwned + Send + Sync + 'static,
74	N: NodeReader<T> + DeserializeOwned + Send + Sync + 'static,
75{
76	type Item = Result<T, StorageError>;
77
78	fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
79		loop {
80			// get next?
81			if self.entries.is_empty() && !self.stack.is_empty() && self.get.is_none() {
82				if let Some(next_cid) = if self.reverse { self.stack.pop_back() } else { self.stack.pop_front() } {
83					let storage = self.storage.clone();
84					self.get = Some(Box::pin(async move { storage.get_deserialized::<N>(&next_cid).await }));
85				}
86			}
87
88			// waiting?
89			if let Some(mut get) = Pin::new(&mut self).get.take() {
90				match get.poll_unpin(cx) {
91					Poll::Ready(Ok(node)) => match node.read(&self.filter) {
92						Either::Left(links) => {
93							self.stack.extend(links);
94							continue;
95						},
96						Either::Right(entries) => {
97							self.entries = entries.into();
98						},
99					},
100					Poll::Ready(Err(e)) => {
101						// clear
102						self.stack.clear();
103						self.entries.clear();
104
105						// fail
106						return Poll::Ready(Some(Err(e)));
107					},
108					Poll::Pending => {
109						self.get = Some(get);
110						return Poll::Pending;
111					},
112				}
113			}
114			break;
115		}
116
117		// read entry
118		Poll::Ready(
119			if self.reverse { self.entries.pop_back() } else { self.entries.pop_front() }.map(|entry| Ok(entry)),
120		)
121	}
122}
123
#[cfg(test)]
mod tests {
	use crate::{library::test::TestStorage, BlockStorage, DefaultNodeSerializer, NodeBuilder, NodeStream};
	use futures::TryStreamExt;

	/// Build a node tree holding `0..10`, store its blocks, and stream the items back
	/// (last-to-first when `reverse` is set).
	async fn stream_numbers(reverse: bool) -> Vec<i32> {
		let storage = TestStorage::default();

		// build and persist
		let mut builder = NodeBuilder::new(storage.max_block_size(), 2, DefaultNodeSerializer::new());
		for value in 0..10 {
			builder.push(value).unwrap();
		}
		let (root, blocks) = builder.into_blocks().unwrap();
		for block in blocks {
			storage.set(block).await.unwrap();
		}

		// stream back
		let stream = NodeStream::from_link(storage.clone(), root);
		let stream = if reverse { stream.with_reverse() } else { stream };
		stream.try_collect::<Vec<i32>>().await.unwrap()
	}

	#[tokio::test]
	async fn test_stream() {
		let list = stream_numbers(false).await;
		assert_eq!(list[..], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]);
	}

	#[tokio::test]
	async fn test_stream_reverse() {
		let list = stream_numbers(true).await;
		assert_eq!(list[..], [9, 8, 7, 6, 5, 4, 3, 2, 1, 0]);
	}
}