// flow_bot/base/context.rs

use std::{
    any::{Any, TypeId},
    collections::HashMap,
    sync::Arc,
    time::Duration,
};

use async_trait::async_trait;
use dashmap::DashMap;
use futures::{SinkExt, stream::SplitSink};
use serde_json::json;
use tokio::{
    net::TcpStream,
    sync::{Mutex, oneshot},
};
use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, tungstenite::Message};

use crate::{
    api::{ApiResponse, api_ext::ApiExt},
    error::FlowError,
    event::BotEvent,
};

use super::extract::FromEvent;

25pub struct Context {
26    pub(crate) sink: Mutex<Option<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>>,
27    pending_requests: Arc<DashMap<String, oneshot::Sender<String>>>,
28    pub(crate) state: StateMap,
29}
30
31impl Context {
32    pub(crate) fn new(mut states: StateMap) -> Self {
33        #[cfg(feature = "turso")]
34        {
35            use crate::extensions::turso::TursoDispatcher;
36            states.insert(TursoDispatcher::new());
37        }
38
39        Self {
40            sink: Mutex::new(None),
41            pending_requests: Arc::new(DashMap::new()),
42            state: states,
43        }
44    }
45}
46
47impl Context {
48    pub(crate) async fn send_obj<T, R>(
49        &self,
50        action: String,
51        obj: T,
52    ) -> Result<ApiResponse<R>, FlowError>
53    where
54        T: serde::Serialize,
55        R: for<'de> serde::Deserialize<'de>,
56    {
57        // Generate random echo string
58        let echo = uuid::Uuid::new_v4().to_string();
59
60        // Create oneshot channel for this specific request
61        let (tx, rx) = oneshot::channel();
62
63        // Register the request BEFORE sending (lock-free)
64        self.pending_requests.insert(echo.clone(), tx);
65
66        // Build and send the message
67        let msg = json!({
68            "action": action,
69            "params": obj,
70            "echo": echo,
71        });
72        let text = serde_json::to_string(&msg)?;
73        let msg = Message::Text(text.into());
74
75        // Send message and release lock immediately
76        {
77            let mut sink = self.sink.lock().await;
78            let sink = sink.as_mut().ok_or(FlowError::NoConnection)?;
79            sink.send(msg).await?;
80        }
81
82        // Wait for response with timeout
83        let response = tokio::time::timeout(std::time::Duration::from_secs(30), rx).await;
84
85        match response {
86            Ok(Ok(data)) => Ok(serde_json::from_str(&data)?),
87            Ok(Err(_)) => Err(FlowError::NoResponse), // Sender dropped
88            Err(_) => {
89                // Timeout occurred, clean up the pending request (lock-free)
90                self.pending_requests.remove(&echo);
91                Err(FlowError::Timeout(30000))
92            }
93        }
94    }
95
96    pub(crate) fn on_recv_echo(&self, echo: String, data: String) {
97        let pending_requests = self.pending_requests.clone();
98        tokio::spawn(async move {
99            // DashMap::remove returns Option<(K, V)>, extract the sender
100            if let Some((_, tx)) = pending_requests.remove(&echo) {
101                let _ = tx.send(data); // Ignore error if receiver dropped
102            }
103            // If echo not found, response arrived after timeout - silently ignore
104        });
105    }
106
107    pub async fn get_self_id(&self) -> Result<i64, FlowError> {
108        let info = self.get_login_info().await?;
109        Ok(info.user_id)
110    }
111}
112
113pub type BotContext = Arc<Context>;
114
115#[async_trait]
116impl FromEvent for BotContext {
117    async fn from_event(context: BotContext, _: BotEvent) -> Option<Self> {
118        Some(context)
119    }
120}
121
/// Type-indexed storage holding at most one value per concrete type,
/// keyed by that type's `TypeId`.
pub(crate) struct StateMap {
    /// Values are stored type-erased and recovered by downcasting.
    map: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
}

126impl StateMap {
127    pub(crate) fn new() -> Self {
128        Self {
129            map: HashMap::new(),
130        }
131    }
132
133    pub(crate) fn insert<T: Any + Send + Sync>(&mut self, state: T) {
134        self.map.insert(TypeId::of::<T>(), Arc::new(state));
135    }
136
137    pub(crate) fn get<T: Any + Send + Sync>(&self) -> Option<Arc<T>> {
138        self.map
139            .get(&TypeId::of::<T>())
140            .and_then(|state| Arc::clone(state).downcast::<T>().ok())
141    }
142}