1use anyhow::Result;
2use aster::agents::{Agent, AgentEvent};
3use aster::conversation::message::Message as AsterMessage;
4use aster::session::session_manager::SessionType;
5use aster::session::SessionManager;
6use axum::response::Redirect;
7use axum::{
8 extract::{
9 ws::{Message, WebSocket, WebSocketUpgrade},
10 Query, Request, State,
11 },
12 http::{StatusCode, Uri},
13 middleware::{self, Next},
14 response::{Html, IntoResponse, Response},
15 routing::get,
16 Json, Router,
17};
18use base64::Engine;
19use futures::{sink::SinkExt, stream::StreamExt};
20use serde::{Deserialize, Serialize};
21use serde_json::Value;
22use std::{net::SocketAddr, sync::Arc};
23use tokio::sync::{Mutex, RwLock};
24use tower_http::cors::{AllowOrigin, Any, CorsLayer};
25use tracing::error;
26use webbrowser;
27
/// Abort handles for in-flight message-processing tasks, keyed by session id.
///
/// One store lives in [`AppState`] and is cloned (as an `Arc`) into every WebSocket
/// connection, so a `cancel` frame can stop the task registered for its session.
type CancellationStore = Arc<RwLock<std::collections::HashMap<String, tokio::task::AbortHandle>>>;

/// Shared state handed to every route handler and to `auth_middleware`.
#[derive(Clone)]
struct AppState {
    /// The single agent created by `handle_web`; every session and connection uses it.
    agent: Arc<Agent>,
    /// Abort handles of running message tasks; see [`CancellationStore`].
    cancellations: CancellationStore,
    /// Token that `auth_middleware` requires on every route except `/api/health`.
    /// `None` disables HTTP authentication.
    auth_token: Option<String>,
    /// Random per-process token that `/ws` requires when `auth_token` is `None`
    /// (empty string otherwise). It is embedded into the page served by `serve_session`.
    ws_token: String,
}
37
/// JSON frames exchanged over `/ws`, discriminated by a `"type"` field.
///
/// `handle_socket` only acts on client-sent `message` and `cancel` frames; all other
/// variants are server-to-client. `tool_response` and `context_exceeded` are defined but
/// not emitted anywhere in this file.
#[derive(Serialize, Deserialize)]
#[serde(tag = "type")]
enum WebSocketMessage {
    /// Client prompt to run through the agent for `session_id`.
    #[serde(rename = "message")]
    Message {
        content: String,
        session_id: String,
        /// Client-supplied; ignored by the server.
        timestamp: i64,
    },
    /// Client request to abort the in-flight task registered for `session_id`.
    #[serde(rename = "cancel")]
    Cancel { session_id: String },
    /// A chunk of assistant text.
    #[serde(rename = "response")]
    Response {
        content: String,
        /// Always `"assistant"` where this file constructs it.
        role: String,
        /// Milliseconds since the Unix epoch (server clock, `timestamp_millis`).
        timestamp: i64,
    },
    /// A tool call the agent is making.
    #[serde(rename = "tool_request")]
    ToolRequest {
        id: String,
        tool_name: String,
        arguments: serde_json::Value,
    },
    /// A tool result (not emitted in this file).
    #[serde(rename = "tool_response")]
    ToolResponse {
        id: String,
        result: serde_json::Value,
        is_error: bool,
    },
    /// A tool call the agent asked to confirm; this file always sends
    /// `needs_confirmation: true`.
    #[serde(rename = "tool_confirmation")]
    ToolConfirmation {
        id: String,
        tool_name: String,
        arguments: serde_json::Value,
        needs_confirmation: bool,
    },
    /// An error while calling the agent or reading its stream.
    #[serde(rename = "error")]
    Error { message: String },
    /// Model reasoning text.
    #[serde(rename = "thinking")]
    Thinking { message: String },
    /// Not emitted in this file.
    #[serde(rename = "context_exceeded")]
    ContextExceeded { message: String },
    /// Sent when a task is aborted via `cancel`.
    #[serde(rename = "cancelled")]
    Cancelled { message: String },
    /// Sent at the end of each processed turn.
    #[serde(rename = "complete")]
    Complete { message: String },
}
85
86async fn auth_middleware(
87 State(state): State<AppState>,
88 req: Request,
89 next: Next,
90) -> Result<Response, StatusCode> {
91 if req.uri().path() == "/api/health" {
92 return Ok(next.run(req).await);
93 }
94
95 let Some(ref expected_token) = state.auth_token else {
96 return Ok(next.run(req).await);
97 };
98
99 if let Some(auth_header) = req.headers().get("authorization") {
100 if let Ok(auth_str) = auth_header.to_str() {
101 if let Some(token) = auth_str.strip_prefix("Bearer ") {
102 if token == expected_token {
103 return Ok(next.run(req).await);
104 }
105 }
106
107 if let Some(basic_token) = auth_str.strip_prefix("Basic ") {
108 if let Ok(decoded) = base64::engine::general_purpose::STANDARD.decode(basic_token) {
109 if let Ok(credentials) = String::from_utf8(decoded) {
110 if credentials.ends_with(expected_token) {
111 return Ok(next.run(req).await);
112 }
113 }
114 }
115 }
116 }
117 }
118
119 let mut response = Response::new("Authentication required".into());
120 *response.status_mut() = StatusCode::UNAUTHORIZED;
121 response.headers_mut().insert(
122 "WWW-Authenticate",
123 "Basic realm=\"Aster Web Interface\"".parse().unwrap(),
124 );
125 Ok(response)
126}
127
128pub async fn handle_web(
129 port: u16,
130 host: String,
131 open: bool,
132 auth_token: Option<String>,
133) -> Result<()> {
134 crate::logging::setup_logging(Some("aster-web"), None)?;
135
136 let config = aster::config::Config::global();
137
138 let provider_name: String = match config.get_aster_provider() {
139 Ok(p) => p,
140 Err(_) => {
141 eprintln!("No provider configured. Run 'aster configure' first");
142 std::process::exit(1);
143 }
144 };
145
146 let model: String = match config.get_aster_model() {
147 Ok(m) => m,
148 Err(_) => {
149 eprintln!("No model configured. Run 'aster configure' first");
150 std::process::exit(1);
151 }
152 };
153
154 let model_config = aster::model::ModelConfig::new(&model)?;
155
156 let init_session = SessionManager::create_session(
157 std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from(".")),
158 "Web Agent Initialization".to_string(),
159 SessionType::Hidden,
160 )
161 .await?;
162
163 let agent = Agent::new();
164 let provider = aster::providers::create(&provider_name, model_config).await?;
165 agent.update_provider(provider, &init_session.id).await?;
166
167 let enabled_configs = aster::config::get_enabled_extensions();
168 for config in enabled_configs {
169 if let Err(e) = agent.add_extension(config.clone()).await {
170 eprintln!("Warning: Failed to load extension {}: {}", config.name(), e);
171 }
172 }
173
174 let ws_token = if auth_token.is_none() {
175 uuid::Uuid::new_v4().to_string()
176 } else {
177 String::new()
178 };
179
180 let state = AppState {
181 agent: Arc::new(agent),
182 cancellations: Arc::new(RwLock::new(std::collections::HashMap::new())),
183 auth_token: auth_token.clone(),
184 ws_token,
185 };
186
187 let cors_layer = if auth_token.is_none() {
188 let allowed_origins = [
189 "http://localhost:3000".parse().unwrap(),
190 "http://127.0.0.1:3000".parse().unwrap(),
191 format!("http://{}:{}", host, port).parse().unwrap(),
192 ];
193 CorsLayer::new()
194 .allow_origin(AllowOrigin::list(allowed_origins))
195 .allow_methods(Any)
196 .allow_headers(Any)
197 } else {
198 CorsLayer::new()
199 .allow_origin(Any)
200 .allow_methods(Any)
201 .allow_headers(Any)
202 };
203
204 let app = Router::new()
205 .route("/", get(serve_index))
206 .route("/session/{session_name}", get(serve_session))
207 .route("/ws", get(websocket_handler))
208 .route("/api/health", get(health_check))
209 .route("/api/sessions", get(list_sessions))
210 .route("/api/sessions/{session_id}", get(get_session))
211 .route("/static/{*path}", get(serve_static))
212 .layer(middleware::from_fn_with_state(
213 state.clone(),
214 auth_middleware,
215 ))
216 .layer(cors_layer)
217 .with_state(state);
218
219 let addr: SocketAddr = format!("{}:{}", host, port).parse()?;
220
221 println!("\n🪿 Starting aster web server");
222 println!(" Provider: {} | Model: {}", provider_name, model);
223 println!(
224 " Working directory: {}",
225 std::env::current_dir()?.display()
226 );
227 println!(" Server: http://{}", addr);
228 println!(" Press Ctrl+C to stop\n");
229
230 if open {
231 let url = format!("http://{}", addr);
232 if let Err(e) = webbrowser::open(&url) {
233 eprintln!("Failed to open browser: {}", e);
234 }
235 }
236
237 let listener = tokio::net::TcpListener::bind(addr).await?;
238 axum::serve(listener, app).await?;
239
240 Ok(())
241}
242
243async fn serve_index(uri: Uri) -> Result<Redirect, (http::StatusCode, String)> {
244 let session = SessionManager::create_session(
245 std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from(".")),
246 "Web session".to_string(),
247 SessionType::User,
248 )
249 .await
250 .map_err(|err| (http::StatusCode::INTERNAL_SERVER_ERROR, err.to_string()))?;
251
252 let redirect_url = if let Some(query) = uri.query() {
253 format!("/session/{}?{}", session.id, query)
254 } else {
255 format!("/session/{}", session.id)
256 };
257
258 Ok(Redirect::to(&redirect_url))
259}
260
261async fn serve_session(
262 axum::extract::Path(session_name): axum::extract::Path<String>,
263 State(state): State<AppState>,
264) -> Html<String> {
265 let html = include_str!("../../static/index.html");
266 let html_with_session = html.replace(
267 "<script src=\"/static/script.js\"></script>",
268 &format!(
269 "<script>window.ASTER_SESSION_NAME = '{}'; window.ASTER_WS_TOKEN = '{}';</script>\n <script src=\"/static/script.js\"></script>",
270 session_name,
271 state.ws_token
272 )
273 );
274 Html(html_with_session)
275}
276
277async fn serve_static(axum::extract::Path(path): axum::extract::Path<String>) -> Response {
278 match path.as_str() {
279 "style.css" => (
280 [("content-type", "text/css")],
281 include_str!("../../static/style.css"),
282 )
283 .into_response(),
284 "script.js" => (
285 [("content-type", "application/javascript")],
286 include_str!("../../static/script.js"),
287 )
288 .into_response(),
289 "img/logo_dark.png" => (
290 [("content-type", "image/png")],
291 include_bytes!("../../static/img/logo_dark.png").to_vec(),
292 )
293 .into_response(),
294 "img/logo_light.png" => (
295 [("content-type", "image/png")],
296 include_bytes!("../../static/img/logo_light.png").to_vec(),
297 )
298 .into_response(),
299 _ => (http::StatusCode::NOT_FOUND, "Not found").into_response(),
300 }
301}
302
303async fn health_check() -> Json<serde_json::Value> {
304 Json(serde_json::json!({
305 "status": "ok",
306 "service": "aster-web"
307 }))
308}
309
310async fn list_sessions() -> Json<serde_json::Value> {
311 match SessionManager::list_sessions().await {
312 Ok(sessions) => {
313 let mut session_info = Vec::new();
314
315 for session in sessions {
316 session_info.push(serde_json::json!({
317 "name": session.id,
318 "path": session.id,
319 "description": session.name,
320 "message_count": session.message_count,
321 "working_dir": session.working_dir
322 }));
323 }
324 Json(serde_json::json!({
325 "sessions": session_info
326 }))
327 }
328 Err(e) => Json(serde_json::json!({
329 "error": e.to_string()
330 })),
331 }
332}
333async fn get_session(
334 axum::extract::Path(session_id): axum::extract::Path<String>,
335) -> Json<serde_json::Value> {
336 match SessionManager::get_session(&session_id, true).await {
337 Ok(session) => Json(serde_json::json!({
338 "metadata": session,
339 "messages": session.conversation.unwrap_or_default().messages()
340 })),
341 Err(e) => Json(serde_json::json!({
342 "error": e.to_string()
343 })),
344 }
345}
346
/// Query string of the `/ws` upgrade request.
#[derive(Deserialize)]
struct WsQuery {
    /// Must equal `AppState::ws_token` when no auth token is configured; a missing value
    /// is treated as an empty string.
    token: Option<String>,
}
351
352async fn websocket_handler(
353 ws: WebSocketUpgrade,
354 State(state): State<AppState>,
355 Query(query): Query<WsQuery>,
356) -> Result<impl IntoResponse, StatusCode> {
357 if state.auth_token.is_none() {
358 let provided_token = query.token.as_deref().unwrap_or("");
359 if provided_token != state.ws_token {
360 tracing::warn!("WebSocket connection rejected: invalid token");
361 return Err(StatusCode::FORBIDDEN);
362 }
363 }
364
365 Ok(ws.on_upgrade(|socket| handle_socket(socket, state)))
366}
367
368async fn handle_socket(socket: WebSocket, state: AppState) {
369 let (sender, mut receiver) = socket.split();
370 let sender = Arc::new(Mutex::new(sender));
371
372 while let Some(msg) = receiver.next().await {
373 if let Ok(msg) = msg {
374 match msg {
375 Message::Text(text) => {
376 match serde_json::from_str::<WebSocketMessage>(&text.to_string()) {
377 Ok(WebSocketMessage::Message {
378 content,
379 session_id,
380 ..
381 }) => {
382 let sender_clone = sender.clone();
383 let agent = state.agent.clone();
384 let session_id_clone = session_id.clone();
385
386 let task_handle = tokio::spawn(async move {
387 let result = process_message_streaming(
388 &agent,
389 session_id_clone,
390 content,
391 sender_clone,
392 )
393 .await;
394
395 if let Err(e) = result {
396 error!("Error processing message: {}", e);
397 }
398 });
399
400 {
401 let mut cancellations = state.cancellations.write().await;
402 cancellations
403 .insert(session_id.clone(), task_handle.abort_handle());
404 }
405
406 let sender_for_abort = sender.clone();
408 let session_id_for_cleanup = session_id.clone();
409 let cancellations_for_cleanup = state.cancellations.clone();
410
411 tokio::spawn(async move {
412 match task_handle.await {
413 Ok(_) => {}
414 Err(e) if e.is_cancelled() => {
415 let mut sender = sender_for_abort.lock().await;
416 let _ = sender
417 .send(Message::Text(
418 serde_json::to_string(
419 &WebSocketMessage::Cancelled {
420 message: "Operation cancelled by user"
421 .to_string(),
422 },
423 )
424 .unwrap()
425 .into(),
426 ))
427 .await;
428 }
429 Err(e) => {
430 error!("Task error: {}", e);
431 }
432 }
433
434 let mut cancellations = cancellations_for_cleanup.write().await;
435 cancellations.remove(&session_id_for_cleanup);
436 });
437 }
438 Ok(WebSocketMessage::Cancel { session_id }) => {
439 let abort_handle = {
441 let mut cancellations = state.cancellations.write().await;
442 cancellations.remove(&session_id)
443 };
444
445 if let Some(handle) = abort_handle {
446 handle.abort();
447
448 let mut sender = sender.lock().await;
450 let _ = sender
451 .send(Message::Text(
452 serde_json::to_string(&WebSocketMessage::Cancelled {
453 message: "Operation cancelled".to_string(),
454 })
455 .unwrap()
456 .into(),
457 ))
458 .await;
459 }
460 }
461 Ok(_) => {
462 }
464 Err(e) => {
465 error!("Failed to parse WebSocket message: {}", e);
466 }
467 }
468 }
469 Message::Close(_) => break,
470 _ => {}
471 }
472 } else {
473 break;
474 }
475 }
476}
477
478async fn process_message_streaming(
479 agent: &Agent,
480 session_id: String,
481 content: String,
482 sender: Arc<Mutex<futures::stream::SplitSink<WebSocket, Message>>>,
483) -> Result<()> {
484 use aster::agents::SessionConfig;
485 use aster::conversation::message::MessageContent;
486 use futures::StreamExt;
487
488 let user_message = AsterMessage::user().with_text(content.clone());
489
490 let provider = agent.provider().await;
491 if provider.is_err() {
492 let error_msg = "I'm not properly configured yet. Please configure a provider through the CLI first using `aster configure`.".to_string();
493 let mut sender = sender.lock().await;
494 let _ = sender
495 .send(Message::Text(
496 serde_json::to_string(&WebSocketMessage::Response {
497 content: error_msg,
498 role: "assistant".to_string(),
499 timestamp: chrono::Utc::now().timestamp_millis(),
500 })
501 .unwrap()
502 .into(),
503 ))
504 .await;
505 return Ok(());
506 }
507
508 let session = SessionManager::get_session(&session_id, true).await?;
509 let mut messages = session.conversation.unwrap_or_default();
510 messages.push(user_message.clone());
511
512 let session_config = SessionConfig {
513 id: session.id.clone(),
514 schedule_id: None,
515 max_turns: None,
516 retry_config: None,
517 system_prompt: None,
518 };
519
520 match agent.reply(user_message, session_config, None).await {
521 Ok(mut stream) => {
522 while let Some(result) = stream.next().await {
523 match result {
524 Ok(AgentEvent::Message(message)) => {
525 for content in &message.content {
526 match content {
527 MessageContent::Text(text) => {
528 let mut sender = sender.lock().await;
529 let _ = sender
530 .send(Message::Text(
531 serde_json::to_string(&WebSocketMessage::Response {
532 content: text.text.clone(),
533 role: "assistant".to_string(),
534 timestamp: chrono::Utc::now().timestamp_millis(),
535 })
536 .unwrap()
537 .into(),
538 ))
539 .await;
540 }
541 MessageContent::ToolRequest(req) => {
542 let mut sender = sender.lock().await;
543 if let Ok(tool_call) = &req.tool_call {
544 let _ = sender
545 .send(Message::Text(
546 serde_json::to_string(
547 &WebSocketMessage::ToolRequest {
548 id: req.id.clone(),
549 tool_name: tool_call.name.to_string(),
550 arguments: Value::from(
551 tool_call.arguments.clone(),
552 ),
553 },
554 )
555 .unwrap()
556 .into(),
557 ))
558 .await;
559 }
560 }
561 MessageContent::ToolResponse(_resp) => {}
562 MessageContent::ToolConfirmationRequest(confirmation) => {
563 let mut sender = sender.lock().await;
564 let _ = sender
565 .send(Message::Text(
566 serde_json::to_string(
567 &WebSocketMessage::ToolConfirmation {
568 id: confirmation.id.clone(),
569 tool_name: confirmation
570 .tool_name
571 .to_string()
572 .clone(),
573 arguments: Value::from(
574 confirmation.arguments.clone(),
575 ),
576 needs_confirmation: true,
577 },
578 )
579 .unwrap()
580 .into(),
581 ))
582 .await;
583
584 agent.handle_confirmation(
585 confirmation.id.clone(),
586 aster::permission::PermissionConfirmation {
587 principal_type: aster::permission::permission_confirmation::PrincipalType::Tool,
588 permission: aster::permission::Permission::AllowOnce,
589 }
590 ).await;
591 }
592 MessageContent::Thinking(thinking) => {
593 let mut sender = sender.lock().await;
594 let _ = sender
595 .send(Message::Text(
596 serde_json::to_string(&WebSocketMessage::Thinking {
597 message: thinking.thinking.clone(),
598 })
599 .unwrap()
600 .into(),
601 ))
602 .await;
603 }
604 _ => {}
605 }
606 }
607 }
608 Ok(AgentEvent::HistoryReplaced(_new_messages)) => {
609 tracing::info!("History replaced, compacting happened in reply");
610 }
611 Ok(AgentEvent::McpNotification(_notification)) => {
612 tracing::info!("Received MCP notification in web interface");
613 }
614 Ok(AgentEvent::ModelChange { model, mode }) => {
615 tracing::info!("Model changed to {} in {} mode", model, mode);
616 }
617 Err(e) => {
618 error!("Error in message stream: {}", e);
619 let mut sender = sender.lock().await;
620 let _ = sender
621 .send(Message::Text(
622 serde_json::to_string(&WebSocketMessage::Error {
623 message: format!("Error: {}", e),
624 })
625 .unwrap()
626 .into(),
627 ))
628 .await;
629 break;
630 }
631 }
632 }
633 }
634 Err(e) => {
635 error!("Error calling agent: {}", e);
636 let mut sender = sender.lock().await;
637 let _ = sender
638 .send(Message::Text(
639 serde_json::to_string(&WebSocketMessage::Error {
640 message: format!("Error: {}", e),
641 })
642 .unwrap()
643 .into(),
644 ))
645 .await;
646 }
647 }
648
649 let mut sender = sender.lock().await;
650 let _ = sender
651 .send(Message::Text(
652 serde_json::to_string(&WebSocketMessage::Complete {
653 message: "Response complete".to_string(),
654 })
655 .unwrap()
656 .into(),
657 ))
658 .await;
659
660 Ok(())
661}