use super::types::*;
use async_trait::async_trait;
use dashmap::DashMap;
use std::sync::Arc;
use std::time::{Duration, Instant};
use tokio::sync::Mutex;
use tracing::info;
/// Settings for [`SessionHandler`]: how long a session may sit idle and which
/// messages reset it.
#[derive(Debug, Clone)]
pub struct SessionConfig {
    /// Idle duration after which the next message gets a freshly created handler.
    pub timeout: Duration,
    /// Keywords that reset the session when they make up the whole message
    /// (compared after trimming, ignoring case).
    pub reset_keywords: Vec<String>,
    /// Prefix that marks a command such as `/reset`; `None` disables command-style resets.
    pub command_prefix: Option<String>,
    /// Command names (without the prefix) that reset the session, compared ignoring case.
    pub reset_commands: Vec<String>,
    /// Reply text sent back to the user after a reset.
    pub reset_reply: String,
}

impl Default for SessionConfig {
    /// One-hour idle timeout, Chinese reset keywords, the `/reset`, `/clear` and `/new`
    /// commands, and a Chinese confirmation reply.
    fn default() -> Self {
        Self {
            timeout: Duration::from_secs(60 * 60),
            // "reset conversation", "new conversation", "clear memory", "start over"
            reset_keywords: vec![
                "重置对话".into(),
                "新对话".into(),
                "清除记忆".into(),
                "重新开始".into(),
            ],
            command_prefix: Some("/".into()),
            reset_commands: vec!["reset".into(), "clear".into(), "new".into()],
            // "Conversation has been reset, please start a new conversation."
            reset_reply: "✅ 对话已重置,请开始新的对话。".into(),
        }
    }
}

impl SessionConfig {
    /// Sets the idle timeout in whole minutes.
    ///
    /// Values too large to express in seconds saturate to `u64::MAX` seconds
    /// instead of overflowing (which would panic in debug builds).
    #[must_use]
    pub fn with_timeout_minutes(mut self, minutes: u64) -> Self {
        self.timeout = Duration::from_secs(minutes.saturating_mul(60));
        self
    }

    /// Sets the idle timeout.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }

    /// Replaces the whole list of reset keywords.
    #[must_use]
    pub fn with_reset_keywords(mut self, keywords: Vec<String>) -> Self {
        self.reset_keywords = keywords;
        self
    }

    /// Appends one reset keyword to the existing list.
    #[must_use]
    pub fn add_reset_keyword(mut self, keyword: impl Into<String>) -> Self {
        self.reset_keywords.push(keyword.into());
        self
    }

    /// Sets the command prefix; `None` disables command-style resets.
    #[must_use]
    pub fn with_command_prefix(mut self, prefix: Option<String>) -> Self {
        self.command_prefix = prefix;
        self
    }

    /// Replaces the list of reset command names (given without the prefix).
    #[must_use]
    pub fn with_reset_commands(mut self, commands: Vec<String>) -> Self {
        self.reset_commands = commands;
        self
    }

    /// Sets the reply sent after a reset.
    #[must_use]
    pub fn with_reset_reply(mut self, reply: impl Into<String>) -> Self {
        self.reset_reply = reply.into();
        self
    }

    /// Returns `true` if `text` should reset the conversation.
    ///
    /// The message is trimmed first. It matches in either of two cases:
    ///
    /// - It equals one of [`reset_keywords`](Self::reset_keywords), ignoring case.
    /// - It starts with [`command_prefix`](Self::command_prefix), and the trimmed
    ///   remainder equals one of [`reset_commands`](Self::reset_commands), ignoring case.
    ///
    /// Commands with trailing arguments, such as `/reset now`, do not match.
    pub fn is_reset(&self, text: &str) -> bool {
        let trimmed = text.trim();
        let lower = trimmed.to_lowercase();
        if self.reset_keywords.iter().any(|kw| kw.to_lowercase() == lower) {
            return true;
        }
        // Command-style reset: `<prefix><command>`, tolerating whitespace after the prefix.
        let command = match self
            .command_prefix
            .as_deref()
            .and_then(|prefix| trimmed.strip_prefix(prefix))
        {
            Some(rest) => rest.trim().to_lowercase(),
            None => return false,
        };
        self.reset_commands
            .iter()
            .any(|c| c.to_lowercase() == command)
    }
}
/// Identifies one conversation as `(channel_id, sender_id)`.
type SessionKey = (String, String);
/// Per-conversation state stored by `SessionHandler`.
struct Session {
    /// Inner handler that receives this conversation's messages; it is replaced
    /// with a new one from the factory after an idle timeout.
    handler: Box<dyn MessageHandler>,
    /// When the most recent message for this session started being handled
    /// (set before the inner handler is awaited).
    last_active: Instant,
}
/// Builds the inner [`MessageHandler`] that backs a conversation session.
///
/// Any `Send + Sync` closure of the form `Fn() -> Box<dyn MessageHandler>` is
/// already a factory through the blanket implementation below.
pub trait SessionFactory: Send + Sync {
    /// Returns a new handler instance; called when a session is first created and
    /// when an idle session's handler is replaced.
    fn create(&self) -> Box<dyn MessageHandler>;
}

impl<F: Fn() -> Box<dyn MessageHandler> + Send + Sync> SessionFactory for F {
    fn create(&self) -> Box<dyn MessageHandler> {
        (self)()
    }
}
/// Payload passed to the end-of-session callback registered on `SessionHandler`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SessionEndInfo {
    /// Channel of the conversation that ended.
    pub channel_id: String,
    /// Sender of the conversation that ended.
    pub sender_id: String,
    /// Why the session ended.
    pub reason: SessionEndReason,
}

/// Why a conversation session ended.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum SessionEndReason {
    /// The user sent a reset keyword or command and the stored session was removed.
    CommandReset,
    /// The session was idle for at least the configured timeout, and its handler
    /// was replaced when the next message arrived.
    TimeoutReplaced,
}
/// [`MessageHandler`] that gives each `(channel_id, sender_id)` pair its own inner
/// handler built by a [`SessionFactory`], with reset commands and an idle timeout.
pub struct SessionHandler {
    // Reset keywords/commands, idle timeout and reset reply text.
    config: SessionConfig,
    // Builds inner handlers for new sessions and for sessions replaced after a timeout.
    factory: Arc<dyn SessionFactory>,
    // Stored sessions keyed by (channel_id, sender_id). Each sits behind an async mutex
    // that is held while the inner handler runs, so one session's messages are handled
    // one at a time.
    sessions: DashMap<SessionKey, Arc<Mutex<Session>>>,
    // Optional callback fired when a session ends (reset command or timeout replacement).
    on_session_end: Option<Arc<dyn Fn(SessionEndInfo) + Send + Sync>>,
}
impl SessionHandler {
    /// Creates a handler that builds a separate inner handler with `factory` for every
    /// `(channel_id, sender_id)` pair.
    pub fn new(config: SessionConfig, factory: impl SessionFactory + 'static) -> Self {
        Self {
            config,
            factory: Arc::new(factory),
            sessions: DashMap::new(),
            on_session_end: None,
        }
    }

    /// Creates a handler using [`SessionConfig::default`].
    pub fn with_defaults(factory: impl SessionFactory + 'static) -> Self {
        Self::new(SessionConfig::default(), factory)
    }

    /// Registers the callback invoked when a session ends; replaces any earlier one.
    ///
    /// The callback is called synchronously from the message-handling path. On a
    /// timeout it runs while that session's lock is held, so it should be cheap.
    #[must_use]
    pub fn with_on_session_end<F>(mut self, callback: F) -> Self
    where
        F: Fn(SessionEndInfo) + Send + Sync + 'static,
    {
        self.on_session_end = Some(Arc::new(callback));
        self
    }

    /// Number of sessions currently stored.
    ///
    /// Idle sessions are not evicted in the background. In this module a session
    /// leaves the map only through a reset, so idle sessions remain in this count.
    pub fn active_sessions(&self) -> usize {
        self.sessions.len()
    }

    /// Returns the session for `key`, creating it with the factory if it is absent.
    ///
    /// Only the `Arc` is returned, so the map entry guard (and its shard lock) is
    /// released before the caller awaits the session mutex.
    fn get_or_create(&self, key: &SessionKey) -> Arc<Mutex<Session>> {
        self.sessions
            .entry(key.clone())
            .or_insert_with(|| {
                Arc::new(Mutex::new(Session {
                    handler: self.factory.create(),
                    last_active: Instant::now(),
                }))
            })
            .clone()
    }

    /// Invokes the registered end-of-session callback, if there is one.
    fn notify_session_end(&self, channel_id: String, sender_id: String, reason: SessionEndReason) {
        if let Some(ref callback) = self.on_session_end {
            callback(SessionEndInfo {
                channel_id,
                sender_id,
                reason,
            });
        }
    }
}
#[async_trait]
impl MessageHandler for SessionHandler {
    /// Routes `msg` to the session for its `(channel_id, sender_id)` pair.
    ///
    /// **Reset messages** (see [`SessionConfig::is_reset`]):
    /// - The stored session is removed.
    /// - The end-of-session callback fires if a session existed.
    /// - The configured reset reply is returned; no inner handler is called.
    ///
    /// **Other messages:**
    /// - The session is fetched or created.
    /// - If it has been idle for at least the configured timeout, its inner handler is
    ///   replaced and the callback fired.
    /// - The message is then forwarded to the inner handler.
    ///
    /// # Errors
    /// Propagates whatever error the session's inner handler returns.
    async fn handle(&self, msg: InboundMessage) -> echo_core::error::Result<OutboundMessage> {
        if self.config.is_reset(&msg.text) {
            let lookup = (msg.channel_id.clone(), msg.sender_id.clone());
            let removed = self.sessions.remove(&lookup);
            if let Some(((channel_id, sender_id), _dropped_session)) = removed {
                self.notify_session_end(channel_id, sender_id, SessionEndReason::CommandReset);
            }
            info!(
                "Session reset by command: ({}, {})",
                msg.channel_id, msg.sender_id
            );
            let reply = OutboundMessage::new(
                &msg.channel_id,
                &msg.sender_id,
                msg.chat_type,
                &self.config.reset_reply,
            );
            return Ok(reply);
        }

        let slot = self.get_or_create(&(msg.channel_id.clone(), msg.sender_id.clone()));
        // Held across the inner `.await` below so one session's messages are handled in turn.
        let mut state = slot.lock().await;
        let idle_for = state.last_active.elapsed();
        if idle_for >= self.config.timeout {
            info!(
                "Session timeout for ({}, {}), elapsed {:?}",
                msg.channel_id, msg.sender_id, idle_for
            );
            self.notify_session_end(
                msg.channel_id.clone(),
                msg.sender_id.clone(),
                SessionEndReason::TimeoutReplaced,
            );
            state.handler = self.factory.create();
        }
        state.last_active = Instant::now();
        state.handler.handle(msg).await
    }

    /// Does not forward outbound messages anywhere; always succeeds.
    async fn reply(&self, _msg: OutboundMessage) -> echo_core::error::Result<()> {
        Ok(())
    }
}