1use ankurah_proto::{self as proto, CollectionId};
2use anyhow::anyhow;
3use dashmap::{DashMap, DashSet};
4use rand::prelude::*;
5use std::{
6 collections::{btree_map::Entry, BTreeMap, BTreeSet},
7 ops::Deref,
8 sync::{Arc, Weak},
9};
10use tokio::sync::{oneshot, RwLock};
11
12use crate::{
13 changes::{ChangeSet, EntityChange, ItemChange},
14 connector::PeerSender,
15 context::Context,
16 error::{RequestError, RetrievalError},
17 model::Entity,
18 policy::PolicyAgent,
19 reactor::Reactor,
20 storage::{StorageCollectionWrapper, StorageEngine},
21 subscription::SubscriptionHandle,
22 task::spawn,
23};
24use tracing::{debug, info, warn};
25
26pub struct PeerState {
27 sender: Box<dyn PeerSender>,
28 _durable: bool,
29 subscriptions: BTreeSet<proto::SubscriptionId>,
30}
31
32pub struct MatchArgs {
33 pub predicate: ankql::ast::Predicate,
34 pub cached: bool,
35}
36
37impl TryInto<MatchArgs> for &str {
38 type Error = ankql::error::ParseError;
39 fn try_into(self) -> Result<MatchArgs, Self::Error> {
40 Ok(MatchArgs { predicate: ankql::parser::parse_selection(self)?, cached: false })
41 }
42}
43impl TryInto<MatchArgs> for String {
44 type Error = ankql::error::ParseError;
45 fn try_into(self) -> Result<MatchArgs, Self::Error> {
46 Ok(MatchArgs { predicate: ankql::parser::parse_selection(&self)?, cached: false })
47 }
48}
49
50impl Into<MatchArgs> for ankql::ast::Predicate {
51 fn into(self) -> MatchArgs { MatchArgs { predicate: self, cached: false } }
52}
53
54impl From<ankql::error::ParseError> for RetrievalError {
55 fn from(e: ankql::error::ParseError) -> Self { RetrievalError::ParseError(e) }
56}
57
58pub struct Node<SE, PA>(Arc<NodeInner<SE, PA>>);
61impl<SE, PA> Clone for Node<SE, PA> {
62 fn clone(&self) -> Self { Self(self.0.clone()) }
63}
64
65pub struct WeakNode<SE, PA>(Weak<NodeInner<SE, PA>>);
66impl<SE, PA> Clone for WeakNode<SE, PA> {
67 fn clone(&self) -> Self { Self(self.0.clone()) }
68}
69
70impl<SE, PA> Deref for Node<SE, PA> {
71 type Target = Arc<NodeInner<SE, PA>>;
72 fn deref(&self) -> &Self::Target { &self.0 }
73}
74
/// Marker trait (no methods) for context data: implementors must be cloneable, thread-safe and
/// `'static`.
pub trait ContextData: Clone + Send + Sync + 'static {}
79
80pub struct NodeInner<SE, PA> {
81 pub id: proto::NodeId,
82 pub durable: bool,
83 storage_engine: Arc<SE>,
84 collections: RwLock<BTreeMap<CollectionId, StorageCollectionWrapper>>,
85
86 entities: Arc<RwLock<EntityMap>>,
87 peer_connections: DashMap<proto::NodeId, PeerState>,
89 durable_peers: DashSet<proto::NodeId>,
90 pending_requests: DashMap<proto::RequestId, oneshot::Sender<Result<proto::NodeResponseBody, RequestError>>>,
91
92 pub reactor: Arc<Reactor<SE, PA>>,
94 _policy_agent: PA,
95}
96
97type EntityMap = BTreeMap<(proto::ID, proto::CollectionId), Weak<Entity>>;
98
99impl<SE, PA> Node<SE, PA>
100where
101 SE: StorageEngine + Send + Sync + 'static,
102 PA: PolicyAgent + Send + Sync + 'static,
103{
104 pub fn new(engine: Arc<SE>, policy_agent: PA) -> Self {
105 let reactor = Reactor::new(engine.clone(), policy_agent.clone());
106 let id = proto::NodeId::new();
107 info!("Node {} created", id);
108 let node = Node(Arc::new(NodeInner {
109 id,
110 storage_engine: engine,
111 collections: RwLock::new(BTreeMap::new()),
112 entities: Arc::new(RwLock::new(BTreeMap::new())),
113 peer_connections: DashMap::new(),
114 durable_peers: DashSet::new(),
115 pending_requests: DashMap::new(),
116 reactor,
117 durable: false,
118 _policy_agent: policy_agent,
119 }));
120 node
123 }
124 pub fn new_durable(engine: Arc<SE>, policy_agent: PA) -> Self {
125 let reactor = Reactor::new(engine.clone(), policy_agent.clone());
126
127 let node = Node(Arc::new(NodeInner {
128 id: proto::NodeId::new(),
129 storage_engine: engine,
130 collections: RwLock::new(BTreeMap::new()),
131 entities: Arc::new(RwLock::new(BTreeMap::new())),
132 peer_connections: DashMap::new(),
133 durable_peers: DashSet::new(),
134 pending_requests: DashMap::new(),
135 reactor,
136 durable: true,
137 _policy_agent: policy_agent,
138 }));
139 node
141 }
142 pub fn weak(&self) -> WeakNode<SE, PA> { WeakNode(Arc::downgrade(&self.0)) }
143}
144
145impl<SE, PA> WeakNode<SE, PA> {
146 pub fn upgrade(&self) -> Option<Node<SE, PA>> { self.0.upgrade().map(Node) }
147}
148
149impl<SE, PA> NodeInner<SE, PA>
150where
151 SE: StorageEngine + Send + Sync + 'static,
152 PA: PolicyAgent + Send + Sync + 'static,
153{
154 pub fn register_peer(&self, presence: proto::Presence, sender: Box<dyn PeerSender>) {
155 info!("Node {} register peer {}", self.id, presence.node_id);
156 self.peer_connections
157 .insert(presence.node_id.clone(), PeerState { sender, _durable: presence.durable, subscriptions: BTreeSet::new() });
158 if presence.durable {
159 self.durable_peers.insert(presence.node_id.clone());
160 }
161 }
163 pub fn deregister_peer(&self, node_id: proto::NodeId) {
164 info!("Node {} deregister peer {}", self.id, node_id);
165 self.peer_connections.remove(&node_id);
166 self.durable_peers.remove(&node_id);
167 }
168 pub async fn request(
169 &self,
170 node_id: proto::NodeId,
171 request_body: proto::NodeRequestBody,
172 ) -> Result<proto::NodeResponseBody, RequestError> {
173 let (response_tx, response_rx) = oneshot::channel::<Result<proto::NodeResponseBody, RequestError>>();
174 let request_id = proto::RequestId::new();
175
176 self.pending_requests.insert(request_id.clone(), response_tx);
178
179 let request = proto::NodeRequest { id: request_id, to: node_id.clone(), from: self.id.clone(), body: request_body };
180
181 {
182 let connection = { self.peer_connections.get(&node_id).ok_or(RequestError::PeerNotConnected)?.sender.cloned() };
185
186 connection.send_message(proto::NodeMessage::Request(request)).await?;
188 }
189
190 response_rx.await.map_err(|_| RequestError::InternalChannelClosed)?
192 }
193
194 pub async fn handle_message(self: &Arc<Self>, message: proto::NodeMessage) -> anyhow::Result<()> {
195 match message {
196 proto::NodeMessage::Request(request) => {
197 info!("Node {} received request {}", self.id, request);
198 if let Some(sender) = { self.peer_connections.get(&request.from).map(|c| c.sender.cloned()) } {
205 let from = request.from.clone();
206 let request_id = request.id.clone();
207 if request.to != self.id {
208 warn!("{} received message from {} but is not the intended recipient", self.id, request.from);
209 }
210
211 let body = match self.handle_request(request).await {
212 Ok(result) => result,
213 Err(e) => proto::NodeResponseBody::Error(e.to_string()),
214 };
215 let _result = sender
216 .send_message(proto::NodeMessage::Response(proto::NodeResponse {
217 request_id,
218 from: self.id.clone(),
219 to: from,
220 body,
221 }))
222 .await;
223 }
224 }
225 proto::NodeMessage::Response(response) => {
226 info!("Node {} received response {}", self.id, response);
227 if let Some((_, tx)) = self.pending_requests.remove(&response.request_id) {
228 tx.send(Ok(response.body)).map_err(|e| anyhow!("Failed to send response: {:?}", e))?;
229 }
230 }
231 }
232 Ok(())
233 }
234
235 async fn handle_request(self: &Arc<Self>, request: proto::NodeRequest) -> anyhow::Result<proto::NodeResponseBody> {
236 match request.body {
237 proto::NodeRequestBody::CommitEvents(events) => {
238 match self.commit_events_local(&events).await {
242 Ok(_) => Ok(proto::NodeResponseBody::CommitComplete),
243 Err(e) => Ok(proto::NodeResponseBody::Error(e.to_string())),
244 }
245 }
246 proto::NodeRequestBody::Fetch { collection, predicate } => {
247 let storage_collection = self.collection(&collection).await;
248 let states: Vec<_> = storage_collection.fetch_states(&predicate).await?.into_iter().collect();
249 Ok(proto::NodeResponseBody::Fetch(states))
250 }
251 proto::NodeRequestBody::Subscribe { subscription_id, collection, predicate } => {
252 self.handle_subscribe_request(request.from, subscription_id, collection, predicate).await
253 }
254 proto::NodeRequestBody::Unsubscribe { subscription_id } => {
255 self.reactor.unsubscribe(subscription_id);
256 if let Some(mut peer_state) = self.peer_connections.get_mut(&request.from) {
258 peer_state.subscriptions.remove(&subscription_id);
259 }
260 Ok(proto::NodeResponseBody::Success)
261 }
262 }
263 }
264
265 pub async fn request_remote_subscribe(
266 &self,
267 sub: &mut SubscriptionHandle,
268 collection_id: &CollectionId,
269 predicate: &ankql::ast::Predicate,
270 ) -> anyhow::Result<()> {
271 let durable_peer_id = self.get_durable_peer_random();
273
274 if let Some(peer_id) = durable_peer_id {
276 match self
277 .request(
278 peer_id,
279 proto::NodeRequestBody::Subscribe {
280 subscription_id: sub.id.clone(),
281 collection: collection_id.clone(),
282 predicate: predicate.clone(),
283 },
284 )
285 .await?
286 {
287 proto::NodeResponseBody::Subscribe { initial, subscription_id: _ } => {
288 let raw_bucket = self.collection(&collection_id).await;
290 for (id, state) in initial {
291 raw_bucket.set_state(id, &state).await.map_err(|e| anyhow!("Failed to set entity: {:?}", e))?;
292 }
293 }
294 proto::NodeResponseBody::Error(e) => {
295 return Err(anyhow!("Error from peer subscription: {}", e));
296 }
297 _ => {
298 return Err(anyhow!("Unexpected response type from peer subscription"));
299 }
300 }
301 }
302 Ok(())
303 }
304 pub async fn request_remote_unsubscribe(&self, sub_id: proto::SubscriptionId, peers: Vec<proto::NodeId>) -> anyhow::Result<()> {
305 futures::future::join_all(
308 peers
309 .iter()
310 .map(|peer_id| self.request(peer_id.clone(), proto::NodeRequestBody::Unsubscribe { subscription_id: sub_id.clone() })),
311 )
312 .await
313 .into_iter()
314 .collect::<Result<Vec<_>, _>>()?;
315
316 Ok(())
317 }
318
319 async fn handle_subscribe_request(
320 self: &Arc<Self>,
321 peer_id: proto::NodeId,
322 sub_id: proto::SubscriptionId,
323 collection_id: CollectionId,
324 predicate: ankql::ast::Predicate,
325 ) -> anyhow::Result<proto::NodeResponseBody> {
326 let storage_collection = self.collection(&collection_id).await;
328 let states = storage_collection.fetch_states(&predicate).await?;
329
330 let node = self.clone();
332 {
333 let peer_id = peer_id.clone();
334 self.reactor
335 .subscribe(sub_id, &collection_id, predicate, move |changeset| {
336 let events: Vec<_> = changeset
338 .changes
339 .iter()
340 .flat_map(|change| match change {
341 ItemChange::Add { events: updates, .. }
342 | ItemChange::Update { events: updates, .. }
343 | ItemChange::Remove { events: updates, .. } => &updates[..],
344 ItemChange::Initial { .. } => &[],
345 })
346 .cloned()
347 .collect();
348
349 if !events.is_empty() {
350 let node = node.clone();
351 let peer_id = peer_id.clone();
352 tokio::spawn(async move {
353 let _ = node.request(peer_id, proto::NodeRequestBody::CommitEvents(events)).await;
354 });
355 }
356 })
357 .await?;
358 };
359
360 if let Some(mut peer_state) = self.peer_connections.get_mut(&peer_id) {
362 peer_state.subscriptions.insert(sub_id);
363 }
364
365 Ok(proto::NodeResponseBody::Subscribe { initial: states, subscription_id: sub_id })
366 }
367
368 pub async fn collection(&self, id: &CollectionId) -> StorageCollectionWrapper {
369 let collections = self.collections.read().await;
370 if let Some(store) = collections.get(id) {
371 return store.clone();
372 }
373 drop(collections);
374
375 let collection = StorageCollectionWrapper::new(self.storage_engine.collection(id).await.unwrap());
376
377 let mut collections = self.collections.write().await;
378
379 if let Entry::Vacant(entry) = collections.entry(id.clone()) {
381 entry.insert(collection.clone());
382 }
383 drop(collections);
384
385 collection
386 }
387
388 pub fn next_entity_id(&self) -> proto::ID { proto::ID::new() }
389
390 pub fn context(self: &Arc<Self>, data: PA::ContextData) -> Context { Context::new(Node(self.clone()), data) }
391
392 async fn commit_events_local(self: &Arc<Self>, events: &Vec<proto::Event>) -> anyhow::Result<()> {
393 info!("Node {} committing events {}", self.id, events.iter().map(|e| e.to_string()).collect::<Vec<_>>().join(","));
394 let mut changes = Vec::new();
395
396 for event in events {
398 let entity = self.get_entity(&event.collection, event.entity_id).await?;
400
401 entity.apply_event(event)?;
402
403 let state = entity.to_state()?;
404 let collection = self.collection(&event.collection).await;
406 collection.add_event(&event).await?;
407 let changed = collection.set_state(event.entity_id, &state).await?;
408
409 if changed {
410 changes.push(EntityChange { entity: entity.clone(), events: vec![event.clone()] });
411 }
412 }
413 self.reactor.notify_change(changes);
414
415 Ok(())
416 }
417
418 pub async fn commit_events(self: &Arc<Self>, events: &Vec<proto::Event>) -> anyhow::Result<()> {
420 self.commit_events_local(events).await?;
421
422 let peer_ids: Vec<_> = self.peer_connections.iter().map(|i| i.key().clone()).collect();
424
425 futures::future::join_all(peer_ids.iter().map(|peer_id| {
426 let events = events.clone();
427 async move {
428 match self.request(peer_id.clone(), proto::NodeRequestBody::CommitEvents(events)).await {
429 Ok(proto::NodeResponseBody::CommitComplete) => {
430 info!("Peer {} confirmed commit", peer_id)
431 }
432 Ok(proto::NodeResponseBody::Error(e)) => warn!("Peer {} error: {}", peer_id, e),
433 Ok(_) => warn!("Peer {} unexpected response type", peer_id),
434 Err(_) => warn!("Peer {} internal channel closed", peer_id),
435 }
436 }
437 }))
438 .await;
439
440 Ok(())
441 }
442
443 pub(crate) async fn insert_entity(self: &Arc<Self>, entity: Arc<Entity>) -> anyhow::Result<()> {
449 match self.entities.write().await.entry((entity.id, entity.collection.clone())) {
450 Entry::Vacant(entry) => {
451 entry.insert(Arc::downgrade(&entity));
452 Ok(())
453 }
454 Entry::Occupied(_) => Err(anyhow!("Entity already exists")),
455 }
456 }
457
458 #[must_use]
462 pub(crate) async fn assert_entity(
463 &self,
464 collection_id: &CollectionId,
465 id: proto::ID,
466 state: &proto::State,
467 ) -> Result<Arc<Entity>, RetrievalError> {
468 let mut entities = self.entities.write().await;
469
470 match entities.entry((id, collection_id.clone())) {
471 Entry::Occupied(mut entry) => {
472 if let Some(entity) = entry.get().upgrade() {
473 entity.apply_state(state)?;
474 Ok(entity)
475 } else {
476 let entity = Arc::new(Entity::from_state(id, collection_id.clone(), state)?);
477 entry.insert(Arc::downgrade(&entity));
478 Ok(entity)
479 }
480 }
481 Entry::Vacant(entry) => {
482 let entity = Arc::new(Entity::from_state(id, collection_id.clone(), state)?);
483 entry.insert(Arc::downgrade(&entity));
484 Ok(entity)
485 }
486 }
487 }
488
489 pub(crate) async fn fetch_entity_from_node(
490 &self,
491 id: proto::ID,
492 collection_id: &CollectionId,
493 ) -> Option<Arc<Entity>> {
495 let entities = self.entities.read().await;
496 if let Some(entity) = entities.get(&(id, collection_id.clone())) {
498 entity.upgrade()
499 } else {
500 None
501 }
502 }
503
504 pub(crate) async fn get_entity(
506 &self,
507 collection_id: &CollectionId,
508 id: proto::ID,
509 ) -> Result<Arc<Entity>, RetrievalError> {
511 info!("fetch_entity {:?}-{:?}", id, collection_id);
512
513 if let Some(local) = self.fetch_entity_from_node(id, collection_id).await {
514 return Ok(local);
515 }
516 debug!("fetch_entity 2");
517
518 let collection = self.collection(collection_id).await;
519 match collection.get_state(id).await {
520 Ok(entity_state) => {
521 return self.assert_entity(collection_id, id, &entity_state).await;
522 }
523 Err(RetrievalError::NotFound(id)) => {
524 let entity = self.assert_entity(collection_id, id, &proto::State::default()).await?;
528 Ok(entity)
529 }
530 Err(e) => Err(e),
531 }
532 }
533
534 pub async fn fetch_entities(
536 self: &Arc<Self>,
537 collection_id: &CollectionId,
538 args: MatchArgs,
539 _cdata: &PA::ContextData,
540 ) -> Result<Vec<Arc<Entity>>, RetrievalError> {
541 if !self.durable {
542 match self.fetch_from_peer(&collection_id, &args.predicate).await {
544 Ok(_) => (),
545 Err(RetrievalError::NoDurablePeers) if args.cached => (),
546 Err(e) => {
547 return Err(e.into());
548 }
549 }
550 }
551
552 let storage_collection = self.collection(&collection_id).await;
554 let states = storage_collection.fetch_states(&args.predicate).await?;
555
556 let mut entities = Vec::new();
558 for (id, state) in states {
559 let entity = self.assert_entity(&collection_id, id, &state).await?;
560 entities.push(entity);
561 }
562 Ok(entities)
563 }
564
565 pub async fn subscribe(
566 self: &Arc<Self>,
567 sub_id: proto::SubscriptionId,
568 collection_id: &CollectionId,
569 args: MatchArgs,
570 callback: Box<dyn Fn(ChangeSet<Arc<Entity>>) + Send + Sync + 'static>,
571 ) -> Result<SubscriptionHandle, RetrievalError> {
572 let mut handle = SubscriptionHandle::new(Box::new(Node(self.clone())) as Box<dyn TNodeErased>, sub_id);
573
574 self.request_remote_subscribe(&mut handle, &collection_id, &args.predicate).await?;
577 self.reactor.subscribe(handle.id, &collection_id, args, callback).await?;
578 Ok(handle)
581 }
582 pub fn unsubscribe(self: &Arc<Self>, handle: &SubscriptionHandle) -> anyhow::Result<()> {
583 let node = self.clone();
584 let peers = handle.peers.clone();
585 let sub_id = handle.id.clone();
586 spawn(async move {
587 node.reactor.unsubscribe(sub_id);
588 if let Err(e) = node.request_remote_unsubscribe(sub_id, peers).await {
589 warn!("Error unsubscribing from peers: {}", e);
590 }
591 });
592 Ok(())
593 }
594 async fn fetch_from_peer(
596 self: &Arc<Self>,
597 collection_id: &CollectionId,
598 predicate: &ankql::ast::Predicate,
599 ) -> anyhow::Result<(), RetrievalError> {
600 let peer_id = self.get_durable_peer_random().ok_or(RetrievalError::NoDurablePeers)?;
601
602 match self
603 .request(peer_id.clone(), proto::NodeRequestBody::Fetch { collection: collection_id.clone(), predicate: predicate.clone() })
604 .await
605 .map_err(|e| RetrievalError::Other(format!("{:?}", e)))?
606 {
607 proto::NodeResponseBody::Fetch(states) => {
608 let raw_bucket = self.collection(collection_id).await;
609 for (id, state) in states {
612 raw_bucket.set_state(id, &state).await.map_err(|e| RetrievalError::Other(format!("{:?}", e)))?;
613 }
614 Ok(())
615 }
616 proto::NodeResponseBody::Error(e) => {
617 debug!("Error from peer fetch: {}", e);
618 Err(RetrievalError::Other(format!("{:?}", e)))
619 }
620 _ => {
621 debug!("Unexpected response type from peer fetch");
622 Err(RetrievalError::Other("Unexpected response type".to_string()))
623 }
624 }
625 }
626
627 pub fn get_durable_peer_random(&self) -> Option<proto::NodeId> {
629 let mut rng = rand::thread_rng();
630 let peers: Vec<_> = self.durable_peers.iter().collect();
632 peers.choose(&mut rng).map(|i| i.key().clone())
633 }
634
635 pub fn get_durable_peers(&self) -> Vec<proto::NodeId> { self.durable_peers.iter().map(|id| id.clone()).collect() }
637}
638
639impl<SE, PA> Drop for Node<SE, PA> {
640 fn drop(&mut self) {
641 info!("Node {} dropped", self.id);
642 }
643}
644
645pub trait TNodeErased: Send + Sync + 'static {
646 fn unsubscribe(&self, handle: &SubscriptionHandle) -> ();
647}
648
649impl<SE, PA> TNodeErased for Node<SE, PA>
650where
651 SE: StorageEngine + Send + Sync + 'static,
652 PA: PolicyAgent + Send + Sync + 'static,
653{
654 fn unsubscribe(&self, handle: &SubscriptionHandle) -> () { let _ = self.0.unsubscribe(handle); }
655}