1use std::{
2 any::{Any, TypeId},
3 collections::HashMap,
4 sync::Arc,
5};
6
7use async_trait::async_trait;
8use futures::{SinkExt, lock::Mutex, stream::SplitSink};
9use serde_json::json;
10use tokio::{net::TcpStream, sync::broadcast::Sender};
11use tokio_tungstenite::{MaybeTlsStream, WebSocketStream, tungstenite::Message};
12
13use crate::{
14 api::{ApiResponse, api_ext::ApiExt},
15 error::FlowError,
16 event::BotEvent,
17};
18
19use super::extract::FromEvent;
20
21pub struct Context {
22 pub(crate) sink: Mutex<Option<SplitSink<WebSocketStream<MaybeTlsStream<TcpStream>>, Message>>>,
23 sender: Sender<(String, String)>,
24 pub(crate) state: StateMap,
25}
26
27impl Context {
28 pub(crate) fn new(states: StateMap) -> Self {
29 let (sender, _) = tokio::sync::broadcast::channel(10);
30 Self {
31 sink: Mutex::new(None),
32 sender,
33 state: states,
34 }
35 }
36}
37
38impl Context {
39 pub(crate) async fn send_obj<T, R>(
40 &self,
41 action: String,
42 obj: T,
43 ) -> Result<ApiResponse<R>, FlowError>
44 where
45 T: serde::Serialize,
46 R: for<'de> serde::Deserialize<'de>,
47 {
48 let echo = uuid::Uuid::new_v4().to_string();
50 let msg = json!({
51 "action": action,
52 "params": obj,
53 "echo": echo,
54 });
55 let text = serde_json::to_string(&msg)?;
56 let msg = Message::Text(text.into());
57 let mut sink = self.sink.lock().await;
58 let sink = sink.as_mut().ok_or(FlowError::NoConnection)?;
59 sink.send(msg).await?;
60
61 let mut recv = self.sender.subscribe();
62 while let Ok((e, r)) = recv.recv().await {
63 if e == echo {
64 return Ok(serde_json::from_str(&r)?);
65 }
66 }
67 Err(FlowError::NoResponse)
68 }
69
70 pub(crate) fn on_recv_echo(&self, echo: String, data: String) {
71 let _ = self.sender.send((echo, data));
72 }
73
74 pub async fn get_self_id(&self) -> Result<i64, FlowError> {
75 let info = self.get_login_info().await?;
76 Ok(info.user_id)
77 }
78}
79
80pub type BotContext = Arc<Context>;
81
82#[async_trait]
83impl FromEvent for BotContext {
84 async fn from_event(context: BotContext, _: BotEvent) -> Option<Self> {
85 Some(context)
86 }
87}
88
89pub(crate) struct StateMap {
90 map: HashMap<TypeId, Arc<dyn Any + Send + Sync>>,
91}
92
93impl StateMap {
94 pub(crate) fn new() -> Self {
95 Self {
96 map: HashMap::new(),
97 }
98 }
99
100 pub(crate) fn insert<T: 'static + Any + Send + Sync>(&mut self, state: T) {
101 self.map.insert(TypeId::of::<T>(), Arc::new(state));
102 }
103
104 pub(crate) fn get<T: Any>(&self) -> Option<Arc<T>> {
105 self.map
106 .get(&TypeId::of::<T>())
107 .and_then(|state| downcast_arc::<T>(state.clone()))
108 }
109}
110
111fn downcast_arc<T: Any>(arc: Arc<dyn Any>) -> Option<Arc<T>> {
112 if arc.is::<T>() {
113 let ptr = Arc::into_raw(arc) as *const T;
114 let arc = unsafe { Arc::from_raw(ptr) };
115 Some(arc)
116 } else {
117 None
118 }
119}