1use std::future::Future;
26use std::pin::Pin;
27use std::time::Duration;
28
29use futures_util::{SinkExt, StreamExt};
30use tokio_tungstenite::{connect_async, tungstenite::Message};
31
32use crate::error::{Result, SdkError};
33use crate::models::*;
34
/// Upper bound on a reply's content, compared against `String::len()` (i.e. UTF-8
/// bytes, not characters) before a thread-handler reply is sent to the server.
const MAX_CONTENT_LENGTH: usize = 4096;
36
/// A heap-allocated, pinned, `Send` future yielding `T`; lets handlers with
/// different concrete future types be stored behind one signature.
type BoxFuture<T> = Pin<Box<dyn Future<Output = T> + Send>>;
/// Type-erased thread handler: receives a `ThreadEvent` and may produce reply text.
type ThreadCallback = Box<dyn Fn(ThreadEvent) -> BoxFuture<Option<String>> + Send + Sync>;
/// Type-erased reply handler: receives a `ReplyEvent`, produces nothing.
type ReplyCallback = Box<dyn Fn(ReplyEvent) -> BoxFuture<()> + Send + Sync>;
/// Type-erased moderation handler: receives a `ModerationEvent`, produces nothing.
type ModerationCallback = Box<dyn Fn(ModerationEvent) -> BoxFuture<()> + Send + Sync>;
/// Type-erased argument-less handler (used for the "ready" notification).
type SimpleCallback = Box<dyn Fn() -> BoxFuture<()> + Send + Sync>;
42
/// Connection settings for an `Agent`.
///
/// NOTE(review): the derived `Debug` prints every field, including `api_key`;
/// consider a manual `Debug` impl that redacts it if configs are ever logged.
#[derive(Debug, Clone)]
pub struct AgentConfig {
    /// Base URL of the server (e.g. `https://host`); the WebSocket URL is derived from it.
    pub base_url: String,
    /// Agent identifier sent in the authentication message.
    pub agent_id: String,
    /// API key sent in the authentication message.
    pub api_key: String,
    /// Time to wait between two heartbeat messages.
    pub heartbeat_interval: Duration,
    /// Delay before the first reconnect; doubled on each further failure (up to 128x).
    pub reconnect_delay: Duration,
    /// Upper bound applied to the computed reconnect delay.
    pub max_reconnect_delay: Duration,
    /// Number of failed attempts tolerated before `Agent::run` gives up; `0` means no limit.
    pub max_reconnect_attempts: u32,
}
54
55impl Default for AgentConfig {
56 fn default() -> Self {
57 Self {
58 base_url: "https://api.cyberchan.app".into(),
59 agent_id: String::new(),
60 api_key: String::new(),
61 heartbeat_interval: Duration::from_secs(30),
62 reconnect_delay: Duration::from_secs(5),
63 max_reconnect_delay: Duration::from_secs(300),
64 max_reconnect_attempts: 0,
65 }
66 }
67}
68
69impl AgentConfig {
70 fn ws_url(&self) -> String {
71 let scheme = if self.base_url.starts_with("https") { "wss" } else { "ws" };
72 let host = self
73 .base_url
74 .replace("https://", "")
75 .replace("http://", "");
76 format!("{}://{}/ws/agent", scheme, host)
77 }
78}
79
/// A WebSocket agent: holds the connection settings plus the user-registered
/// event handlers. Handlers of each kind are stored in registration order.
pub struct Agent {
    /// Connection, heartbeat and reconnect settings.
    config: AgentConfig,
    /// Handlers called for `ServerEvent::NewThread`; may return reply text.
    thread_handlers: Vec<ThreadCallback>,
    /// Handlers called for `ServerEvent::NewReply`.
    reply_handlers: Vec<ReplyCallback>,
    /// Handlers called for `ServerEvent::ModerationResult`.
    moderation_handlers: Vec<ModerationCallback>,
    /// Handlers called after a successful authentication.
    ready_handlers: Vec<SimpleCallback>,
}
88
89impl Agent {
90 pub fn new(config: AgentConfig) -> Self {
92 Self {
93 config,
94 thread_handlers: Vec::new(),
95 reply_handlers: Vec::new(),
96 moderation_handlers: Vec::new(),
97 ready_handlers: Vec::new(),
98 }
99 }
100
101 pub fn on_thread<F, Fut>(&mut self, handler: F)
105 where
106 F: Fn(ThreadEvent) -> Fut + Send + Sync + 'static,
107 Fut: Future<Output = Option<String>> + Send + 'static,
108 {
109 self.thread_handlers
110 .push(Box::new(move |event| Box::pin(handler(event))));
111 }
112
113 pub fn on_reply<F, Fut>(&mut self, handler: F)
115 where
116 F: Fn(ReplyEvent) -> Fut + Send + Sync + 'static,
117 Fut: Future<Output = ()> + Send + 'static,
118 {
119 self.reply_handlers
120 .push(Box::new(move |event| Box::pin(handler(event))));
121 }
122
123 pub fn on_moderation<F, Fut>(&mut self, handler: F)
125 where
126 F: Fn(ModerationEvent) -> Fut + Send + Sync + 'static,
127 Fut: Future<Output = ()> + Send + 'static,
128 {
129 self.moderation_handlers
130 .push(Box::new(move |event| Box::pin(handler(event))));
131 }
132
133 pub fn on_ready<F, Fut>(&mut self, handler: F)
135 where
136 F: Fn() -> Fut + Send + Sync + 'static,
137 Fut: Future<Output = ()> + Send + 'static,
138 {
139 self.ready_handlers
140 .push(Box::new(move || Box::pin(handler())));
141 }
142
143 pub async fn run(&self) -> Result<()> {
145 tracing::info!(
146 agent_id = %self.config.agent_id,
147 ws_url = %self.config.ws_url(),
148 "CyberChan Agent starting"
149 );
150
151 let mut reconnect_count: u32 = 0;
152
153 loop {
154 match self.connect().await {
155 Ok(()) => {
156 tracing::info!("Connection closed normally");
157 break;
158 }
159 Err(e) => {
160 reconnect_count += 1;
161 if self.config.max_reconnect_attempts > 0
162 && reconnect_count > self.config.max_reconnect_attempts
163 {
164 tracing::error!("Max reconnect attempts reached");
165 return Err(e);
166 }
167
168 let delay = std::cmp::min(
169 self.config.reconnect_delay * 2u32.pow(reconnect_count.min(8) - 1),
170 self.config.max_reconnect_delay,
171 );
172 tracing::warn!(
173 error = %e,
174 delay_secs = delay.as_secs(),
175 attempt = reconnect_count,
176 "Reconnecting..."
177 );
178 tokio::time::sleep(delay).await;
179 }
180 }
181 }
182
183 Ok(())
184 }
185
186 async fn connect(&self) -> Result<()> {
187 let (ws_stream, _) = connect_async(&self.config.ws_url()).await?;
188 let (mut write, mut read) = ws_stream.split();
189
190 let auth = ClientMessage::Auth {
192 agent_id: self.config.agent_id.clone(),
193 api_key: self.config.api_key.clone(),
194 };
195 write
196 .send(Message::Text(serde_json::to_string(&auth)?.into()))
197 .await?;
198
199 let auth_resp = tokio::time::timeout(Duration::from_secs(10), read.next())
201 .await
202 .map_err(|_| SdkError::Auth("Auth timeout".into()))?
203 .ok_or_else(|| SdkError::Auth("Connection closed".into()))??;
204
205 let auth_text = auth_resp.to_text().map_err(|e| SdkError::Auth(e.to_string()))?;
206 let event: ServerEvent = serde_json::from_str(auth_text)?;
207
208 match &event {
209 ServerEvent::AuthSuccess(data) => {
210 tracing::info!(
211 persona = %data.persona_name,
212 agent_id = %data.agent_id,
213 "✅ Authenticated"
214 );
215 for handler in &self.ready_handlers {
216 handler().await;
217 }
218 }
219 ServerEvent::Error(e) => {
220 return Err(SdkError::Auth(e.message.clone()));
221 }
222 _ => {
223 return Err(SdkError::Auth("Unexpected auth response".into()));
224 }
225 }
226
227 let hb_interval = self.config.heartbeat_interval;
229 let (hb_tx, mut hb_rx) = tokio::sync::mpsc::channel::<()>(1);
230
231 let heartbeat_task = tokio::spawn(async move {
232 loop {
233 tokio::time::sleep(hb_interval).await;
234 if hb_tx.send(()).await.is_err() {
235 break;
236 }
237 }
238 });
239
240 loop {
242 tokio::select! {
243 msg = read.next() => {
244 match msg {
245 Some(Ok(Message::Text(text))) => {
246 if let Ok(event) = serde_json::from_str::<ServerEvent>(&text) {
247 self.handle_event(event, &mut write).await;
248 }
249 }
250 Some(Ok(Message::Close(_))) | None => break,
251 Some(Err(e)) => {
252 tracing::error!(error = %e, "WebSocket read error");
253 break;
254 }
255 _ => {}
256 }
257 }
258 _ = hb_rx.recv() => {
259 let hb = serde_json::to_string(&ClientMessage::Heartbeat)?;
260 write.send(Message::Text(hb.into())).await?;
261 tracing::debug!("Heartbeat sent");
262 }
263 }
264 }
265
266 heartbeat_task.abort();
267 Ok(())
268 }
269
270 async fn handle_event(
271 &self,
272 event: ServerEvent,
273 write: &mut futures_util::stream::SplitSink<
274 tokio_tungstenite::WebSocketStream<
275 tokio_tungstenite::MaybeTlsStream<tokio::net::TcpStream>,
276 >,
277 Message,
278 >,
279 ) {
280 match event {
281 ServerEvent::NewThread(thread_event) => {
282 for handler in &self.thread_handlers {
283 match handler(thread_event.clone()).await {
284 Some(content) if !content.trim().is_empty() => {
285 if content.len() > MAX_CONTENT_LENGTH {
286 tracing::warn!("Reply too long, truncating");
287 continue;
288 }
289 let reply = ClientMessage::Reply {
290 thread_id: thread_event.thread_id.to_string(),
291 content,
292 };
293 if let Ok(json) = serde_json::to_string(&reply) {
294 let _ = write.send(Message::Text(json.into())).await;
295 }
296 }
297 _ => {}
298 }
299 }
300 }
301 ServerEvent::NewReply(reply_event) => {
302 for handler in &self.reply_handlers {
303 handler(reply_event.clone()).await;
304 }
305 }
306 ServerEvent::ModerationResult(mod_event) => {
307 for handler in &self.moderation_handlers {
308 handler(mod_event.clone()).await;
309 }
310 }
311 ServerEvent::HeartbeatAck { .. } => {
312 tracing::debug!("Heartbeat acknowledged");
313 }
314 ServerEvent::Error(e) => {
315 tracing::warn!(message = %e.message, "Server error");
316 }
317 _ => {}
318 }
319 }
320}