flow_bot/base/
context.rs

1use std::{
2    any::{Any, TypeId},
3    collections::HashMap,
4    sync::Arc,
5};
6
7use async_trait::async_trait;
8use futures::{SinkExt, lock::Mutex, stream::SplitSink};
9use serde_json::json;
10use tokio::{net::TcpStream, sync::broadcast::Sender};
11use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, tungstenite::Message};
12
13use crate::{
14    api::{ApiResponse, api_ext::ApiExt},
15    error::FlowError,
16    event::BotEvent,
17};
18
19use super::extract::FromEvent;
20
21pub struct Context {
22    pub(crate) sink: Mutex<Option<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>>,
23    sender: Sender<(String, String)>,
24    pub(crate) state: StateMap,
25}
26
27impl Context {
28    pub(crate) fn new(states: StateMap) -> Self {
29        let (sender, _) = tokio::sync::broadcast::channel(10);
30        Self {
31            sink: Mutex::new(None),
32            sender,
33            state: states,
34        }
35    }
36}
37
38impl Context {
39    pub(crate) async fn send_obj<T, R>(
40        &self,
41        action: String,
42        obj: T,
43    ) -> Result<ApiResponse<R>, FlowError>
44    where
45        T: serde::Serialize,
46        R: for<'de> serde::Deserialize<'de>,
47    {
48        // generate random echo string
49        let echo = uuid::Uuid::new_v4().to_string();
50        let msg = json!({
51            "action": action,
52            "params": obj,
53            "echo": echo,
54        });
55        let text = serde_json::to_string(&msg)?;
56        let msg = Message::Text(text.into());
57        let mut sink = self.sink.lock().await;
58        let sink = sink.as_mut().ok_or(FlowError::NoConnection)?;
59        sink.send(msg).await?;
60
61        let mut recv = self.sender.subscribe();
62        while let Ok((e, r)) = recv.recv().await {
63            if e == echo {
64                return Ok(serde_json::from_str(&r)?);
65            }
66        }
67        Err(FlowError::NoResponse)
68    }
69
70    pub(crate) fn on_recv_echo(&self, echo: String, data: String) {
71        let _ = self.sender.send((echo, data));
72    }
73
74    pub async fn get_self_id(&self) -> Result<i64, FlowError> {
75        let info = self.get_login_info().await?;
76        Ok(info.user_id)
77    }
78}
79
80pub type BotContext = Arc<Context>;
81
82#[async_trait]
83impl FromEvent for BotContext {
84    async fn from_event(context: BotContext, _: BotEvent) -> Option<Self> {
85        Some(context)
86    }
87}
88
89pub(crate) struct StateMap {
90    map: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
91}
92
93impl StateMap {
94    pub(crate) fn new() -> Self {
95        Self {
96            map: HashMap::new(),
97        }
98    }
99
100    pub(crate) fn insert<T: 'static + Any + Send + Sync>(&mut self, state: T) {
101        self.map.insert(TypeId::of::<T>(), Arc::new(state));
102    }
103
104    pub(crate) fn get<T: Any>(&self) -> Option<Arc<T>> {
105        self.map
106            .get(&TypeId::of::<T>())
107            .and_then(|state| downcast_arc::<T>(state.clone()))
108    }
109}
110
/// Attempts to turn a type-erased `Arc` into an `Arc<T>`.
///
/// Returns `None` (dropping the handle that was passed in) when the erased
/// value is not a `T`.
fn downcast_arc<T: Any>(arc: Arc<dyn Any>) -> Option<Arc<T>> {
    if !arc.is::<T>() {
        return None;
    }
    let raw = Arc::into_raw(arc) as *const T;
    // SAFETY: `raw` was just produced by `Arc::into_raw`, and the `is::<T>()`
    // check above established that the erased value really is a `T`, so the
    // allocation may be reinterpreted as an `Arc<T>`.
    Some(unsafe { Arc::from_raw(raw) })
}