use crate::{
ContextTokens, DaemonClient, GatewayConfig, GatewayMessage, StreamAccumulator, StreamResult,
UserIdMap,
};
use std::{collections::HashMap, path::Path, sync::Arc};
use tokio::sync::mpsc;
use wcore::protocol::message::{
ClientMessage, ReplyToAsk, ServerMessage, StreamMsg, server_message,
};
/// Entry point for the WeChat gateway.
///
/// Builds a daemon client for `daemon_socket` and resolves the default agent from the agents
/// directory under the config dir. It starts the WeChat transport only when a `wechat` section
/// with a non-empty token is configured, then waits for Ctrl-C before returning.
///
/// # Errors
/// Returns an error if listening for the Ctrl-C signal fails.
pub async fn run(daemon_socket: &str, config: &GatewayConfig) -> anyhow::Result<()> {
    let socket_path = Path::new(daemon_socket);
    let daemon = Arc::new(DaemonClient::new(socket_path));
    let agent = crate::resolve_default_agent(
        &wcore::paths::CONFIG_DIR.join(wcore::paths::AGENTS_DIR),
    );
    tracing::info!(agent = %agent, "wechat gateway starting");

    match &config.wechat {
        None => {
            tracing::warn!(platform = "wechat", "no wechat config provided");
        }
        Some(wc) if wc.token.is_empty() => {
            tracing::warn!(platform = "wechat", "token is empty, skipping");
        }
        Some(wc) => {
            spawn_wechat(wc, agent, daemon).await;
        }
    }

    // The transport runs on detached tasks; keep the process alive until interrupted.
    tokio::signal::ctrl_c().await?;
    tracing::info!("wechat gateway shutting down");
    Ok(())
}
/// Starts the WeChat transport as two detached tasks and returns once both are spawned.
///
/// - a poller (`crate::poll_loop`) that pushes inbound messages into a channel and shares the
///   per-chat user-id map and the context-token map with the dispatcher;
/// - a dispatcher (`wechat_loop`) that consumes that channel, applying the
///   `allowed_users` whitelist when it is non-empty.
async fn spawn_wechat(
    wc: &gateway::config::WechatConfig,
    agent: String,
    client: Arc<DaemonClient>,
) {
    let (inbound_tx, inbound_rx) = mpsc::unbounded_channel::<GatewayMessage>();
    let context_tokens: ContextTokens = Arc::new(std::sync::Mutex::new(HashMap::new()));
    let chat_users: UserIdMap = Arc::new(std::sync::Mutex::new(HashMap::new()));

    // Poller: owns its own HTTP client plus handles to the shared maps.
    {
        let poll_http = reqwest::Client::new();
        let (poll_url, poll_token) = (wc.base_url.clone(), wc.token.clone());
        let (poll_tokens, poll_users) = (Arc::clone(&context_tokens), Arc::clone(&chat_users));
        tokio::spawn(async move {
            crate::poll_loop(
                poll_http,
                poll_url,
                poll_token,
                inbound_tx,
                poll_tokens,
                poll_users,
            )
            .await;
        });
    }

    let whitelist = wc
        .allowed_users
        .iter()
        .cloned()
        .collect::<std::collections::HashSet<String>>();
    if !whitelist.is_empty() {
        tracing::info!(
            platform = "wechat",
            count = whitelist.len(),
            "user whitelist active"
        );
    }

    tokio::spawn(wechat_loop(
        inbound_rx,
        agent,
        client,
        context_tokens,
        chat_users,
        whitelist,
        wc.base_url.clone(),
        wc.token.clone(),
    ));
    tracing::info!(platform = "wechat", "channel transport started");
}
/// Book-keeping for the daemon stream currently running for one chat.
struct ChatStream {
    /// Spawned task driving the stream; resolves to the stream's final outcome.
    handle: tokio::task::JoinHandle<StreamResult>,
    /// Session id the stream was started with (`None` means a fresh session was requested).
    session_id: Option<u64>,
    /// Sender for follow-up user messages, which the stream forwards as replies.
    reply_tx: mpsc::UnboundedSender<String>,
}

impl ChatStream {
    /// Reports whether the backing task has already completed.
    fn is_finished(&self) -> bool {
        let ChatStream { handle, .. } = self;
        tokio::task::JoinHandle::is_finished(handle)
    }
}
async fn reap_chat(chat: ChatStream) -> Option<u64> {
match chat.handle.await {
Ok(StreamResult::Ok { session_id }) => Some(session_id),
_ => chat.session_id,
}
}
#[allow(clippy::too_many_arguments)]
async fn wechat_loop(
mut rx: mpsc::UnboundedReceiver<GatewayMessage>,
agent: String,
client: Arc<DaemonClient>,
ctx_tokens: ContextTokens,
user_ids: UserIdMap,
allowed_users: std::collections::HashSet<String>,
base_url: String,
token: String,
) {
let mut chats: HashMap<i64, ChatStream> = HashMap::new();
let mut sessions: HashMap<i64, u64> = HashMap::new();
let http = reqwest::Client::new();
while let Some(msg) = rx.recv().await {
let chat_id = msg.chat_id;
let content = msg.content.clone();
if !allowed_users.is_empty() {
let user_id = user_ids.lock().unwrap().get(&chat_id).cloned();
if let Some(ref uid) = user_id
&& !allowed_users.contains(uid)
{
tracing::debug!(user_id = %uid, chat_id, "dropping non-allowed user");
continue;
}
}
tracing::info!(agent = %agent, chat_id, "wechat dispatch");
if let Some(chat_stream) = chats.get(&chat_id) {
if chat_stream.is_finished() {
let chat_stream = chats.remove(&chat_id).unwrap();
if let Some(sid) = reap_chat(chat_stream).await {
sessions.insert(chat_id, sid);
}
} else {
let _ = chat_stream.reply_tx.send(content);
continue;
}
}
let session = sessions.get(&chat_id).copied();
let sender = user_ids
.lock()
.unwrap()
.get(&chat_id)
.cloned()
.unwrap_or_default();
let (reply_tx, reply_rx) = mpsc::unbounded_channel();
let handle = {
let client = client.clone();
let agent = agent.clone();
let http = http.clone();
let base_url = base_url.clone();
let token = token.clone();
let ctx_tokens = ctx_tokens.clone();
let user_ids = user_ids.clone();
let sender = sender.clone();
tokio::spawn(async move {
let result = wx_stream(
&http,
&client,
&agent,
chat_id,
&content,
&sender,
session,
reply_rx,
&base_url,
&token,
&ctx_tokens,
&user_ids,
)
.await;
match result {
StreamResult::SessionError if session.is_some() => {
tracing::warn!(agent = %&agent, chat_id, "session error, retrying");
let (_retry_tx, retry_rx) = mpsc::unbounded_channel();
wx_stream(
&http,
&client,
&agent,
chat_id,
&content,
&sender,
None,
retry_rx,
&base_url,
&token,
&ctx_tokens,
&user_ids,
)
.await
}
other => other,
}
})
};
chats.insert(
chat_id,
ChatStream {
handle,
session_id: session,
reply_tx,
},
);
}
tracing::info!(platform = "wechat", "channel loop ended");
}
#[allow(clippy::too_many_arguments)]
async fn wx_stream(
http: &reqwest::Client,
client: &DaemonClient,
agent: &str,
chat_id: i64,
content: &str,
sender: &str,
session: Option<u64>,
mut reply_rx: mpsc::UnboundedReceiver<String>,
base_url: &str,
token: &str,
ctx_tokens: &ContextTokens,
user_ids: &UserIdMap,
) -> StreamResult {
tracing::info!(agent, chat_id, %sender, ?session, "starting stream");
let client_msg = ClientMessage::from(StreamMsg {
agent: agent.to_string(),
content: content.to_string(),
session,
sender: Some(sender.to_string()),
cwd: None,
});
let mut server_rx = client.send(client_msg).await;
let mut acc = StreamAccumulator::new();
loop {
tokio::select! {
server_msg = server_rx.recv() => {
match server_msg {
Some(ServerMessage { msg: Some(server_message::Msg::Stream(event)) }) => {
acc.push(&event);
if let Some(questions) = acc.take_pending_questions() {
let question_text = questions
.iter()
.map(|q| format!("{}: {}", q.header, q.question))
.collect::<Vec<_>>()
.join("\n");
let to_user = user_ids.lock().unwrap().get(&chat_id).cloned();
let ctx = ctx_tokens.lock().unwrap().get(
to_user.as_deref().unwrap_or("")
).cloned();
if let (Some(to), Some(ct)) = (to_user, ctx) {
let _ = crate::api::send_message(
http, base_url, token, &to, &ct, &question_text,
).await;
}
}
if acc.is_done() {
break;
}
}
Some(ServerMessage { msg: Some(server_message::Msg::Error(err)) }) => {
acc.set_error(err.message);
break;
}
Some(_) => {}
None => break,
}
}
reply = reply_rx.recv() => {
if let Some(reply_content) = reply {
if let Some(session_id) = acc.session() {
let reply_msg = ClientMessage::from(ReplyToAsk {
session: session_id,
content: reply_content,
});
let _ = client.send(reply_msg).await;
}
}
}
}
}
if let Some(err) = acc.error() {
tracing::warn!(agent, chat_id, "stream error: {err}");
let to_user = user_ids.lock().unwrap().get(&chat_id).cloned();
let ctx = ctx_tokens
.lock()
.unwrap()
.get(to_user.as_deref().unwrap_or(""))
.cloned();
if let (Some(to), Some(ct)) = (to_user, ctx) {
let _ =
crate::api::send_message(http, base_url, token, &to, &ct, &format!("Error: {err}"))
.await;
}
return if session.is_some() {
StreamResult::SessionError
} else {
StreamResult::Failed
};
}
let final_text = acc.render();
if !final_text.is_empty() {
tracing::info!(agent, chat_id, len = final_text.len(), "sending reply");
let to_user = user_ids.lock().unwrap().get(&chat_id).cloned();
let ctx = ctx_tokens
.lock()
.unwrap()
.get(to_user.as_deref().unwrap_or(""))
.cloned();
if let (Some(to), Some(ct)) = (to_user, ctx) {
if let Err(e) =
crate::api::send_message(http, base_url, token, &to, &ct, &final_text).await
{
tracing::warn!(agent, chat_id, "failed to send reply: {e}");
} else {
tracing::info!(agent, chat_id, "reply sent");
}
} else {
tracing::warn!(agent, chat_id, "no user_id or context_token for reply");
}
} else {
tracing::debug!(agent, chat_id, "stream ended with empty response");
}
match acc.session() {
Some(session_id) => {
tracing::info!(agent, chat_id, session_id, "stream completed");
StreamResult::Ok { session_id }
}
None => {
tracing::warn!(agent, chat_id, "stream completed without session");
StreamResult::Failed
}
}
}