use std::collections::HashSet;
use std::sync::atomic::{AtomicU64, AtomicUsize, Ordering};
use std::sync::Arc;
use super::OBC;
use crate::ah::GetSelfs;
use crate::prelude::{Bot, Version};
use crate::structs::Selft;
use crate::util::{Echo, EchoInner, EchoS, GetSelf, ProtocolItem};
use crate::{ActionHandler, EventHandler, GetStatus, GetVersion, OneBot};
use crate::{WalleError, WalleResult};
use async_trait::async_trait;
use dashmap::DashMap;
use tokio::sync::{mpsc, oneshot};
use tokio::task::JoinHandle;
use tracing::{info, warn};
#[cfg(feature = "http")]
mod app_http;
#[cfg(feature = "websocket")]
mod app_ws;
pub(crate) type EchoMap<R> = Arc<DashMap<EchoS, oneshot::Sender<R>>>;
pub struct AppOBC<A, R> {
pub(crate) echos: EchoMap<R>, pub(crate) seq: AtomicU64, pub(crate) bots: Arc<BotMap<A>>, }
impl<A, R> AppOBC<A, R> {
pub fn new() -> Self {
Default::default()
}
}
impl<A, R> Default for AppOBC<A, R> {
fn default() -> Self {
Self {
echos: Arc::new(DashMap::new()),
seq: AtomicU64::default(),
bots: Arc::new(Default::default()),
}
}
}
impl<A, R> AppOBC<A, R> {
pub(crate) fn next_seg(&self) -> EchoS {
EchoS(Some(EchoInner::S(
self.seq
.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
.to_string(),
)))
}
}
#[async_trait]
impl<E, A, R> ActionHandler<E, A, R> for AppOBC<A, R>
where
E: ProtocolItem + Clone + GetSelf,
A: ProtocolItem + GetSelf,
R: ProtocolItem,
{
type Config = crate::config::AppConfig;
async fn start<AH, EH>(
&self,
ob: &Arc<OneBot<AH, EH>>,
config: crate::config::AppConfig,
) -> WalleResult<Vec<JoinHandle<()>>>
where
AH: ActionHandler<E, A, R> + Send + Sync + 'static,
EH: EventHandler<E, A, R> + Send + Sync + 'static,
{
let mut tasks = vec![];
#[cfg(feature = "websocket")]
{
self.wsr(ob, config.websocket_rev, &mut tasks).await?;
self.ws(ob, config.websocket, &mut tasks).await?;
}
#[cfg(feature = "http")]
{
self.webhook(ob, config.http_webhook, &mut tasks).await?;
self.http(ob, config.http, &mut tasks).await?;
}
Ok(tasks)
}
async fn call<AH, EH>(&self, action: A, _ob: &Arc<OneBot<AH, EH>>) -> WalleResult<R>
where
AH: ActionHandler<E, A, R> + Send + Sync + 'static,
EH: EventHandler<E, A, R> + Send + Sync + 'static,
{
match self.bots.get_bot(&action.get_self()) {
Some(action_txs) => {
let (tx, rx) = oneshot::channel();
let seq = self.next_seg();
self.echos.insert(seq.clone(), tx);
action_txs
.first()
.unwrap() .send(seq.pack(action))
.map_err(|e| {
warn!(target: super::OBC, "send action error: {}", e);
WalleError::Other(e.to_string())
})?;
match tokio::time::timeout(std::time::Duration::from_secs(10), rx).await {
Ok(Ok(res)) => Ok(res),
Ok(Err(e)) => {
warn!(target: super::OBC, "resp recv error: {:?}", e);
Err(WalleError::Other(e.to_string()))
}
Err(_) => {
warn!(target: super::OBC, "resp timeout");
Err(WalleError::Other("resp timeout".to_string()))
}
}
}
None => {
warn!(target: super::OBC, "bot not found");
Err(WalleError::BotNotExist)
}
}
}
}
#[derive(Debug)]
pub(crate) struct BotMap<A> {
pub(crate) conn_seq: AtomicUsize,
pub(crate) bots: DashMap<Selft, (String, Vec<mpsc::UnboundedSender<Echo<A>>>)>,
pub(crate) conns: DashMap<usize, (mpsc::UnboundedSender<Echo<A>>, HashSet<Selft>)>,
}
impl<A> Default for BotMap<A> {
fn default() -> Self {
Self {
conn_seq: AtomicUsize::default(),
bots: DashMap::default(),
conns: DashMap::default(),
}
}
}
impl<A> BotMap<A> {
fn new_connect(&self) -> (usize, mpsc::UnboundedReceiver<Echo<A>>) {
let seq = self.conn_seq.fetch_add(1, Ordering::Relaxed);
let (tx, rx) = mpsc::unbounded_channel();
self.conns.insert(seq, (tx, HashSet::default()));
(seq, rx)
}
fn connect_closs(&self, tx_seq: &usize) {
if let Some(selfts) = self.conns.remove(tx_seq) {
for selft in selfts.1 .1 {
let mut bot = self.bots.get_mut(&selft).unwrap();
bot.value_mut()
.1
.retain(|htx| !htx.same_channel(&selfts.1 .0));
if bot.value().1.is_empty() {
drop(bot);
self.bots.remove(&selft);
}
}
}
}
fn connect_update(&self, tx_seq: &usize, bots: Vec<Bot>, implt: &str) {
let mut get = self.conns.get_mut(tx_seq).unwrap();
let tx = get.0.clone();
let selfts = &mut get.1;
for bot in bots {
match (bot.online, selfts.contains(&bot.selft)) {
(true, false) => {
selfts.insert(bot.selft.clone());
self.bots
.entry(bot.selft.clone())
.or_insert((implt.to_string(), vec![]))
.1
.push(tx.clone());
info!(
target: OBC,
"New Bot connected: {}-{}", bot.selft.platform, bot.selft.user_id
);
}
(false, true) => {
selfts.remove(&bot.selft);
if let Some(mut bots) = self.bots.get_mut(&bot.selft) {
bots.value_mut().1.retain(|htx| !htx.same_channel(&tx));
if bots.1.is_empty() {
drop(bots);
self.bots.remove(&bot.selft);
}
}
info!(
target: OBC,
"Bot disconnected: {}-{}", bot.selft.platform, bot.selft.user_id
);
}
_ => {}
}
}
}
fn get_bot(&self, bot: &Selft) -> Option<Vec<mpsc::UnboundedSender<Echo<A>>>> {
self.bots.get(bot).as_deref().cloned().map(|v| v.1)
}
fn selfts(&self) -> Vec<Selft> {
self.bots.iter().map(|i| i.key().clone()).collect()
}
}
#[async_trait]
impl<A, R> GetSelfs for AppOBC<A, R>
where
A: Send + Sync,
R: Send + Sync,
{
async fn get_selfs(&self) -> Vec<Selft> {
self.bots.selfts()
}
async fn get_impl(&self, selft: &Selft) -> String {
self.bots
.bots
.get(selft)
.map(|v| v.value().0.clone())
.unwrap_or_default()
}
}
#[async_trait]
impl<A, R> GetStatus for AppOBC<A, R>
where
A: Send + Sync,
R: Send + Sync,
{
async fn is_good(&self) -> bool {
true
}
}
impl<A, R> GetVersion for AppOBC<A, R> {
fn get_version(&self) -> Version {
Version {
implt: super::OBC.to_owned(),
version: crate::VERSION.to_owned(),
onebot_version: "12".to_owned(),
}
}
}
#[test]
fn test_bot_map() {
let map = BotMap::<crate::action::Action>::default();
let (seq, _) = map.new_connect();
assert_eq!(seq, 0);
let (seq, _) = map.new_connect();
assert_eq!(seq, 1);
assert_eq!(
map.conns
.iter()
.map(|i| i.key().clone())
.collect::<HashSet<_>>(),
HashSet::from([1, 0])
);
let self0 = Selft {
platform: "".to_owned(),
user_id: "0".to_owned(),
};
let self1 = Selft {
platform: "".to_owned(),
user_id: "1".to_owned(),
};
assert!(map.bots.get(&self0).is_none());
assert!(map.conns.get(&0).unwrap().1.len() == 1);
assert_eq!(map.bots.get(&self1).unwrap().1.len(), 1);
assert!(map.get_bot(&self1).is_some());
}