ng_net/
actor.rs

1/*
2 * Copyright (c) 2022-2025 Niko Bonnieure, Par le Peuple, NextGraph.org developers
3 * All rights reserved.
4 * Licensed under the Apache License, Version 2.0
5 * <LICENSE-APACHE2 or http://www.apache.org/licenses/LICENSE-2.0>
6 * or the MIT license <LICENSE-MIT or http://opensource.org/licenses/MIT>,
7 * at your option. All files in the project carrying such
8 * notice may not be copied, modified, or distributed except
9 * according to those terms.
10*/
11
//! Actors handle the messages of the protocol; the types they have in common are defined here.
13
14use std::any::TypeId;
15use std::marker::PhantomData;
16use std::sync::Arc;
17
18use async_std::stream::StreamExt;
19use async_std::sync::Mutex;
20use futures::{channel::mpsc, SinkExt};
21
22use ng_repo::errors::{NgError, ProtocolError, ServerError};
23use ng_repo::log::*;
24
25use crate::utils::{spawn_and_log_error, Receiver, ResultSend, Sender};
26use crate::{connection::*, types::ProtocolMessage};
27
28impl TryFrom<ProtocolMessage> for () {
29    type Error = ProtocolError;
30    fn try_from(_msg: ProtocolMessage) -> Result<Self, Self::Error> {
31        Ok(())
32    }
33}
34
35#[doc(hidden)]
36#[async_trait::async_trait]
37pub trait EActor: Send + Sync + std::fmt::Debug {
38    async fn respond(
39        &mut self,
40        msg: ProtocolMessage,
41        fsm: Arc<Mutex<NoiseFSM>>,
42    ) -> Result<(), ProtocolError>;
43
44    fn set_id(&mut self, _id: i64) {}
45}
46
47#[derive(Debug)]
48pub(crate) struct Actor<
49    'a,
50    A: Into<ProtocolMessage> + std::fmt::Debug,
51    B: TryFrom<ProtocolMessage, Error = ProtocolError> + std::fmt::Debug + Sync,
52> {
53    id: i64,
54    phantom_a: PhantomData<&'a A>,
55    phantom_b: PhantomData<&'a B>,
56    receiver: Option<Receiver<ConnectionCommand>>,
57    receiver_tx: Sender<ConnectionCommand>,
58    //initiator: bool,
59}
60
61#[derive(Debug)]
62pub enum SoS<B> {
63    Single(B),
64    Stream(Receiver<B>),
65}
66
67impl<B> SoS<B> {
68    pub fn is_single(&self) -> bool {
69        if let Self::Single(_b) = self {
70            true
71        } else {
72            false
73        }
74    }
75    pub fn is_stream(&self) -> bool {
76        !self.is_single()
77    }
78    pub fn unwrap_single(self) -> B {
79        match self {
80            Self::Single(s) => s,
81            Self::Stream(_s) => {
82                panic!("called `unwrap_single()` on a `Stream` value")
83            }
84        }
85    }
86    pub fn unwrap_stream(self) -> Receiver<B> {
87        match self {
88            Self::Stream(s) => s,
89            Self::Single(_s) => {
90                panic!("called `unwrap_stream()` on a `Single` value")
91            }
92        }
93    }
94}
95
96impl<
97        A: Into<ProtocolMessage> + std::fmt::Debug + 'static,
98        B: TryFrom<ProtocolMessage, Error = ProtocolError> + Sync + Send + std::fmt::Debug + 'static,
99    > Actor<'_, A, B>
100{
101    pub fn new(id: i64, _initiator: bool) -> Self {
102        let (receiver_tx, receiver) = mpsc::unbounded::<ConnectionCommand>();
103        Self {
104            id,
105            receiver: Some(receiver),
106            receiver_tx,
107            phantom_a: PhantomData,
108            phantom_b: PhantomData,
109            //initiator,
110        }
111    }
112
113    // pub fn verify(&self, msg: ProtocolMessage) -> bool {
114    //     self.initiator && msg.type_id() == TypeId::of::<B>()
115    //         || !self.initiator && msg.type_id() == TypeId::of::<A>()
116    // }
117
118    pub fn detach_receiver(&mut self) -> Receiver<ConnectionCommand> {
119        self.receiver.take().unwrap()
120    }
121
122    pub async fn request(
123        &mut self,
124        msg: ProtocolMessage,
125        fsm: Arc<Mutex<NoiseFSM>>,
126    ) -> Result<SoS<B>, NgError> {
127        fsm.lock().await.send(msg).await?;
128        let mut receiver = self.receiver.take().unwrap();
129        match receiver.next().await {
130            Some(ConnectionCommand::Msg(msg)) => {
131                if let Some(bm) = msg.is_streamable() {
132                    if bm.result() == Into::<u16>::into(ServerError::PartialContent)
133                        && TypeId::of::<B>() != TypeId::of::<()>()
134                    {
135                        let (mut b_sender, b_receiver) = mpsc::unbounded::<B>();
136                        let response = msg.try_into().map_err(|e| {
137                            log_err!("msg.try_into {}", e);
138                            ProtocolError::ActorError
139                        })?;
140                        b_sender
141                            .send(response)
142                            .await
143                            .map_err(|_err| ProtocolError::IoError)?;
144                        async fn pump_stream<C: TryFrom<ProtocolMessage, Error = ProtocolError>>(
145                            mut actor_receiver: Receiver<ConnectionCommand>,
146                            mut sos_sender: Sender<C>,
147                            fsm: Arc<Mutex<NoiseFSM>>,
148                            id: i64,
149                        ) -> ResultSend<()> {
150                            async move {
151                                while let Some(ConnectionCommand::Msg(msg)) =
152                                    actor_receiver.next().await
153                                {
154                                    if let Some(bm) = msg.is_streamable() {
155                                        if bm.result()
156                                            == Into::<u16>::into(ServerError::EndOfStream)
157                                        {
158                                            break;
159                                        }
160                                        let response = msg.try_into();
161                                        if response.is_err() {
162                                            // TODO deal with errors.
163                                            break;
164                                        }
165                                        if sos_sender.send(response.unwrap()).await.is_err() {
166                                            break;
167                                        }
168                                    } else {
169                                        // todo deal with error (not a ClientMessage)
170                                        break;
171                                    }
172                                }
173                                fsm.lock().await.remove_actor(id).await;
174                            }
175                            .await;
176                            Ok(())
177                        }
178                        spawn_and_log_error(pump_stream::<B>(
179                            receiver,
180                            b_sender,
181                            Arc::clone(&fsm),
182                            self.id,
183                        ));
184                        return Ok(SoS::<B>::Stream(b_receiver));
185                    }
186                }
187                fsm.lock().await.remove_actor(self.id).await;
188                let server_error: Result<ServerError, NgError> = (&msg).try_into();
189                //log_debug!("server_error {:?}", server_error);
190                if server_error.is_ok() {
191                    return Err(NgError::ServerError(server_error.unwrap()));
192                }
193                let response: B = match msg.try_into() {
194                    Ok(b) => b,
195                    Err(ProtocolError::ServerError) => {
196                        return Err(NgError::ServerError(server_error?));
197                    }
198                    Err(e) => return Err(NgError::ProtocolError(e)),
199                };
200                Ok(SoS::<B>::Single(response))
201            }
202            Some(ConnectionCommand::ProtocolError(e)) => Err(e.into()),
203            Some(ConnectionCommand::Error(e)) => Err(ProtocolError::from(e).into()),
204            Some(ConnectionCommand::Close) => Err(ProtocolError::Closing.into()),
205            _ => Err(ProtocolError::ActorError.into()),
206        }
207    }
208
209    pub fn new_responder(id: i64) -> Box<Self> {
210        Box::new(Self::new(id, false))
211    }
212
213    pub fn get_receiver_tx(&self) -> Sender<ConnectionCommand> {
214        self.receiver_tx.clone()
215    }
216
217    pub fn id(&self) -> i64 {
218        self.id
219    }
220}
221
#[cfg(test)]
mod test {

    use crate::actor::*;
    use crate::actors::*;

    /// Smoke test: an actor reports the id it was built with, and its command receiver can
    /// be detached.
    #[async_std::test]
    pub async fn test_actor() {
        let mut actor = Actor::<Noise, Noise>::new(1, true);
        assert_eq!(actor.id(), 1);
        let _receiver = actor.detach_receiver();
    }
}