Skip to main content

aster_cli/commands/
web.rs

1use anyhow::Result;
2use aster::agents::{Agent, AgentEvent};
3use aster::conversation::message::Message as AsterMessage;
4use aster::session::session_manager::SessionType;
5use aster::session::SessionManager;
6use axum::response::Redirect;
7use axum::{
8    extract::{
9        ws::{Message, WebSocket, WebSocketUpgrade},
10        Query, Request, State,
11    },
12    http::{StatusCode, Uri},
13    middleware::{self, Next},
14    response::{Html, IntoResponse, Response},
15    routing::get,
16    Json, Router,
17};
18use base64::Engine;
19use futures::{sink::SinkExt, stream::StreamExt};
20use serde::{Deserialize, Serialize};
21use serde_json::Value;
22use std::{net::SocketAddr, sync::Arc};
23use tokio::sync::{Mutex, RwLock};
24use tower_http::cors::{AllowOrigin, Any, CorsLayer};
25use tracing::error;
26use webbrowser;
27
28type CancellationStore = Arc<RwLock<std::collections::HashMap<String, tokio::task::AbortHandle>>>;
29
30#[derive(Clone)]
31struct AppState {
32    agent: Arc<Agent>,
33    cancellations: CancellationStore,
34    auth_token: Option<String>,
35    ws_token: String,
36}
37
38#[derive(Serialize, Deserialize)]
39#[serde(tag = "type")]
40enum WebSocketMessage {
41    #[serde(rename = "message")]
42    Message {
43        content: String,
44        session_id: String,
45        timestamp: i64,
46    },
47    #[serde(rename = "cancel")]
48    Cancel { session_id: String },
49    #[serde(rename = "response")]
50    Response {
51        content: String,
52        role: String,
53        timestamp: i64,
54    },
55    #[serde(rename = "tool_request")]
56    ToolRequest {
57        id: String,
58        tool_name: String,
59        arguments: serde_json::Value,
60    },
61    #[serde(rename = "tool_response")]
62    ToolResponse {
63        id: String,
64        result: serde_json::Value,
65        is_error: bool,
66    },
67    #[serde(rename = "tool_confirmation")]
68    ToolConfirmation {
69        id: String,
70        tool_name: String,
71        arguments: serde_json::Value,
72        needs_confirmation: bool,
73    },
74    #[serde(rename = "error")]
75    Error { message: String },
76    #[serde(rename = "thinking")]
77    Thinking { message: String },
78    #[serde(rename = "context_exceeded")]
79    ContextExceeded { message: String },
80    #[serde(rename = "cancelled")]
81    Cancelled { message: String },
82    #[serde(rename = "complete")]
83    Complete { message: String },
84}
85
86async fn auth_middleware(
87    State(state): State<AppState>,
88    req: Request,
89    next: Next,
90) -> Result<Response, StatusCode> {
91    if req.uri().path() == "/api/health" {
92        return Ok(next.run(req).await);
93    }
94
95    let Some(ref expected_token) = state.auth_token else {
96        return Ok(next.run(req).await);
97    };
98
99    if let Some(auth_header) = req.headers().get("authorization") {
100        if let Ok(auth_str) = auth_header.to_str() {
101            if let Some(token) = auth_str.strip_prefix("Bearer ") {
102                if token == expected_token {
103                    return Ok(next.run(req).await);
104                }
105            }
106
107            if let Some(basic_token) = auth_str.strip_prefix("Basic ") {
108                if let Ok(decoded) = base64::engine::general_purpose::STANDARD.decode(basic_token) {
109                    if let Ok(credentials) = String::from_utf8(decoded) {
110                        if credentials.ends_with(expected_token) {
111                            return Ok(next.run(req).await);
112                        }
113                    }
114                }
115            }
116        }
117    }
118
119    let mut response = Response::new("Authentication required".into());
120    *response.status_mut() = StatusCode::UNAUTHORIZED;
121    response.headers_mut().insert(
122        "WWW-Authenticate",
123        "Basic realm=\"Aster Web Interface\"".parse().unwrap(),
124    );
125    Ok(response)
126}
127
128pub async fn handle_web(
129    port: u16,
130    host: String,
131    open: bool,
132    auth_token: Option<String>,
133) -> Result<()> {
134    crate::logging::setup_logging(Some("aster-web"), None)?;
135
136    let config = aster::config::Config::global();
137
138    let provider_name: String = match config.get_aster_provider() {
139        Ok(p) => p,
140        Err(_) => {
141            eprintln!("No provider configured. Run 'aster configure' first");
142            std::process::exit(1);
143        }
144    };
145
146    let model: String = match config.get_aster_model() {
147        Ok(m) => m,
148        Err(_) => {
149            eprintln!("No model configured. Run 'aster configure' first");
150            std::process::exit(1);
151        }
152    };
153
154    let model_config = aster::model::ModelConfig::new(&model)?;
155
156    let init_session = SessionManager::create_session(
157        std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from(".")),
158        "Web Agent Initialization".to_string(),
159        SessionType::Hidden,
160    )
161    .await?;
162
163    let agent = Agent::new();
164    let provider = aster::providers::create(&provider_name, model_config).await?;
165    agent.update_provider(provider, &init_session.id).await?;
166
167    let enabled_configs = aster::config::get_enabled_extensions();
168    for config in enabled_configs {
169        if let Err(e) = agent.add_extension(config.clone()).await {
170            eprintln!("Warning: Failed to load extension {}: {}", config.name(), e);
171        }
172    }
173
174    let ws_token = if auth_token.is_none() {
175        uuid::Uuid::new_v4().to_string()
176    } else {
177        String::new()
178    };
179
180    let state = AppState {
181        agent: Arc::new(agent),
182        cancellations: Arc::new(RwLock::new(std::collections::HashMap::new())),
183        auth_token: auth_token.clone(),
184        ws_token,
185    };
186
187    let cors_layer = if auth_token.is_none() {
188        let allowed_origins = [
189            "http://localhost:3000".parse().unwrap(),
190            "http://127.0.0.1:3000".parse().unwrap(),
191            format!("http://{}:{}", host, port).parse().unwrap(),
192        ];
193        CorsLayer::new()
194            .allow_origin(AllowOrigin::list(allowed_origins))
195            .allow_methods(Any)
196            .allow_headers(Any)
197    } else {
198        CorsLayer::new()
199            .allow_origin(Any)
200            .allow_methods(Any)
201            .allow_headers(Any)
202    };
203
204    let app = Router::new()
205        .route("/", get(serve_index))
206        .route("/session/{session_name}", get(serve_session))
207        .route("/ws", get(websocket_handler))
208        .route("/api/health", get(health_check))
209        .route("/api/sessions", get(list_sessions))
210        .route("/api/sessions/{session_id}", get(get_session))
211        .route("/static/{*path}", get(serve_static))
212        .layer(middleware::from_fn_with_state(
213            state.clone(),
214            auth_middleware,
215        ))
216        .layer(cors_layer)
217        .with_state(state);
218
219    let addr: SocketAddr = format!("{}:{}", host, port).parse()?;
220
221    println!("\n🪿 Starting aster web server");
222    println!("   Provider: {} | Model: {}", provider_name, model);
223    println!(
224        "   Working directory: {}",
225        std::env::current_dir()?.display()
226    );
227    println!("   Server: http://{}", addr);
228    println!("   Press Ctrl+C to stop\n");
229
230    if open {
231        let url = format!("http://{}", addr);
232        if let Err(e) = webbrowser::open(&url) {
233            eprintln!("Failed to open browser: {}", e);
234        }
235    }
236
237    let listener = tokio::net::TcpListener::bind(addr).await?;
238    axum::serve(listener, app).await?;
239
240    Ok(())
241}
242
243async fn serve_index(uri: Uri) -> Result<Redirect, (http::StatusCode, String)> {
244    let session = SessionManager::create_session(
245        std::env::current_dir().unwrap_or_else(|_| std::path::PathBuf::from(".")),
246        "Web session".to_string(),
247        SessionType::User,
248    )
249    .await
250    .map_err(|err| (http::StatusCode::INTERNAL_SERVER_ERROR, err.to_string()))?;
251
252    let redirect_url = if let Some(query) = uri.query() {
253        format!("/session/{}?{}", session.id, query)
254    } else {
255        format!("/session/{}", session.id)
256    };
257
258    Ok(Redirect::to(&redirect_url))
259}
260
261async fn serve_session(
262    axum::extract::Path(session_name): axum::extract::Path<String>,
263    State(state): State<AppState>,
264) -> Html<String> {
265    let html = include_str!("../../static/index.html");
266    let html_with_session = html.replace(
267        "<script src=\"/static/script.js\"></script>",
268        &format!(
269            "<script>window.ASTER_SESSION_NAME = '{}'; window.ASTER_WS_TOKEN = '{}';</script>\n    <script src=\"/static/script.js\"></script>",
270            session_name,
271            state.ws_token
272        )
273    );
274    Html(html_with_session)
275}
276
277async fn serve_static(axum::extract::Path(path): axum::extract::Path<String>) -> Response {
278    match path.as_str() {
279        "style.css" => (
280            [("content-type", "text/css")],
281            include_str!("../../static/style.css"),
282        )
283            .into_response(),
284        "script.js" => (
285            [("content-type", "application/javascript")],
286            include_str!("../../static/script.js"),
287        )
288            .into_response(),
289        "img/logo_dark.png" => (
290            [("content-type", "image/png")],
291            include_bytes!("../../static/img/logo_dark.png").to_vec(),
292        )
293            .into_response(),
294        "img/logo_light.png" => (
295            [("content-type", "image/png")],
296            include_bytes!("../../static/img/logo_light.png").to_vec(),
297        )
298            .into_response(),
299        _ => (http::StatusCode::NOT_FOUND, "Not found").into_response(),
300    }
301}
302
303async fn health_check() -> Json<serde_json::Value> {
304    Json(serde_json::json!({
305        "status": "ok",
306        "service": "aster-web"
307    }))
308}
309
310async fn list_sessions() -> Json<serde_json::Value> {
311    match SessionManager::list_sessions().await {
312        Ok(sessions) => {
313            let mut session_info = Vec::new();
314
315            for session in sessions {
316                session_info.push(serde_json::json!({
317                    "name": session.id,
318                    "path": session.id,
319                    "description": session.name,
320                    "message_count": session.message_count,
321                    "working_dir": session.working_dir
322                }));
323            }
324            Json(serde_json::json!({
325                "sessions": session_info
326            }))
327        }
328        Err(e) => Json(serde_json::json!({
329            "error": e.to_string()
330        })),
331    }
332}
333async fn get_session(
334    axum::extract::Path(session_id): axum::extract::Path<String>,
335) -> Json<serde_json::Value> {
336    match SessionManager::get_session(&session_id, true).await {
337        Ok(session) => Json(serde_json::json!({
338            "metadata": session,
339            "messages": session.conversation.unwrap_or_default().messages()
340        })),
341        Err(e) => Json(serde_json::json!({
342            "error": e.to_string()
343        })),
344    }
345}
346
347#[derive(Deserialize)]
348struct WsQuery {
349    token: Option<String>,
350}
351
352async fn websocket_handler(
353    ws: WebSocketUpgrade,
354    State(state): State<AppState>,
355    Query(query): Query<WsQuery>,
356) -> Result<impl IntoResponse, StatusCode> {
357    if state.auth_token.is_none() {
358        let provided_token = query.token.as_deref().unwrap_or("");
359        if provided_token != state.ws_token {
360            tracing::warn!("WebSocket connection rejected: invalid token");
361            return Err(StatusCode::FORBIDDEN);
362        }
363    }
364
365    Ok(ws.on_upgrade(|socket| handle_socket(socket, state)))
366}
367
368async fn handle_socket(socket: WebSocket, state: AppState) {
369    let (sender, mut receiver) = socket.split();
370    let sender = Arc::new(Mutex::new(sender));
371
372    while let Some(msg) = receiver.next().await {
373        if let Ok(msg) = msg {
374            match msg {
375                Message::Text(text) => {
376                    match serde_json::from_str::<WebSocketMessage>(&text.to_string()) {
377                        Ok(WebSocketMessage::Message {
378                            content,
379                            session_id,
380                            ..
381                        }) => {
382                            let sender_clone = sender.clone();
383                            let agent = state.agent.clone();
384                            let session_id_clone = session_id.clone();
385
386                            let task_handle = tokio::spawn(async move {
387                                let result = process_message_streaming(
388                                    &agent,
389                                    session_id_clone,
390                                    content,
391                                    sender_clone,
392                                )
393                                .await;
394
395                                if let Err(e) = result {
396                                    error!("Error processing message: {}", e);
397                                }
398                            });
399
400                            {
401                                let mut cancellations = state.cancellations.write().await;
402                                cancellations
403                                    .insert(session_id.clone(), task_handle.abort_handle());
404                            }
405
406                            // Handle task completion and cleanup
407                            let sender_for_abort = sender.clone();
408                            let session_id_for_cleanup = session_id.clone();
409                            let cancellations_for_cleanup = state.cancellations.clone();
410
411                            tokio::spawn(async move {
412                                match task_handle.await {
413                                    Ok(_) => {}
414                                    Err(e) if e.is_cancelled() => {
415                                        let mut sender = sender_for_abort.lock().await;
416                                        let _ = sender
417                                            .send(Message::Text(
418                                                serde_json::to_string(
419                                                    &WebSocketMessage::Cancelled {
420                                                        message: "Operation cancelled by user"
421                                                            .to_string(),
422                                                    },
423                                                )
424                                                .unwrap()
425                                                .into(),
426                                            ))
427                                            .await;
428                                    }
429                                    Err(e) => {
430                                        error!("Task error: {}", e);
431                                    }
432                                }
433
434                                let mut cancellations = cancellations_for_cleanup.write().await;
435                                cancellations.remove(&session_id_for_cleanup);
436                            });
437                        }
438                        Ok(WebSocketMessage::Cancel { session_id }) => {
439                            // Cancel the active operation for this session
440                            let abort_handle = {
441                                let mut cancellations = state.cancellations.write().await;
442                                cancellations.remove(&session_id)
443                            };
444
445                            if let Some(handle) = abort_handle {
446                                handle.abort();
447
448                                // Send cancellation confirmation
449                                let mut sender = sender.lock().await;
450                                let _ = sender
451                                    .send(Message::Text(
452                                        serde_json::to_string(&WebSocketMessage::Cancelled {
453                                            message: "Operation cancelled".to_string(),
454                                        })
455                                        .unwrap()
456                                        .into(),
457                                    ))
458                                    .await;
459                            }
460                        }
461                        Ok(_) => {
462                            // Ignore other message types
463                        }
464                        Err(e) => {
465                            error!("Failed to parse WebSocket message: {}", e);
466                        }
467                    }
468                }
469                Message::Close(_) => break,
470                _ => {}
471            }
472        } else {
473            break;
474        }
475    }
476}
477
478async fn process_message_streaming(
479    agent: &Agent,
480    session_id: String,
481    content: String,
482    sender: Arc<Mutex<futures::stream::SplitSink<WebSocket, Message>>>,
483) -> Result<()> {
484    use aster::agents::SessionConfig;
485    use aster::conversation::message::MessageContent;
486    use futures::StreamExt;
487
488    let user_message = AsterMessage::user().with_text(content.clone());
489
490    let provider = agent.provider().await;
491    if provider.is_err() {
492        let error_msg = "I'm not properly configured yet. Please configure a provider through the CLI first using `aster configure`.".to_string();
493        let mut sender = sender.lock().await;
494        let _ = sender
495            .send(Message::Text(
496                serde_json::to_string(&WebSocketMessage::Response {
497                    content: error_msg,
498                    role: "assistant".to_string(),
499                    timestamp: chrono::Utc::now().timestamp_millis(),
500                })
501                .unwrap()
502                .into(),
503            ))
504            .await;
505        return Ok(());
506    }
507
508    let session = SessionManager::get_session(&session_id, true).await?;
509    let mut messages = session.conversation.unwrap_or_default();
510    messages.push(user_message.clone());
511
512    let session_config = SessionConfig {
513        id: session.id.clone(),
514        schedule_id: None,
515        max_turns: None,
516        retry_config: None,
517        system_prompt: None,
518    };
519
520    match agent.reply(user_message, session_config, None).await {
521        Ok(mut stream) => {
522            while let Some(result) = stream.next().await {
523                match result {
524                    Ok(AgentEvent::Message(message)) => {
525                        for content in &message.content {
526                            match content {
527                                MessageContent::Text(text) => {
528                                    let mut sender = sender.lock().await;
529                                    let _ = sender
530                                        .send(Message::Text(
531                                            serde_json::to_string(&WebSocketMessage::Response {
532                                                content: text.text.clone(),
533                                                role: "assistant".to_string(),
534                                                timestamp: chrono::Utc::now().timestamp_millis(),
535                                            })
536                                            .unwrap()
537                                            .into(),
538                                        ))
539                                        .await;
540                                }
541                                MessageContent::ToolRequest(req) => {
542                                    let mut sender = sender.lock().await;
543                                    if let Ok(tool_call) = &req.tool_call {
544                                        let _ = sender
545                                            .send(Message::Text(
546                                                serde_json::to_string(
547                                                    &WebSocketMessage::ToolRequest {
548                                                        id: req.id.clone(),
549                                                        tool_name: tool_call.name.to_string(),
550                                                        arguments: Value::from(
551                                                            tool_call.arguments.clone(),
552                                                        ),
553                                                    },
554                                                )
555                                                .unwrap()
556                                                .into(),
557                                            ))
558                                            .await;
559                                    }
560                                }
561                                MessageContent::ToolResponse(_resp) => {}
562                                MessageContent::ToolConfirmationRequest(confirmation) => {
563                                    let mut sender = sender.lock().await;
564                                    let _ = sender
565                                        .send(Message::Text(
566                                            serde_json::to_string(
567                                                &WebSocketMessage::ToolConfirmation {
568                                                    id: confirmation.id.clone(),
569                                                    tool_name: confirmation
570                                                        .tool_name
571                                                        .to_string()
572                                                        .clone(),
573                                                    arguments: Value::from(
574                                                        confirmation.arguments.clone(),
575                                                    ),
576                                                    needs_confirmation: true,
577                                                },
578                                            )
579                                            .unwrap()
580                                            .into(),
581                                        ))
582                                        .await;
583
584                                    agent.handle_confirmation(
585                                        confirmation.id.clone(),
586                                        aster::permission::PermissionConfirmation {
587                                            principal_type: aster::permission::permission_confirmation::PrincipalType::Tool,
588                                            permission: aster::permission::Permission::AllowOnce,
589                                        }
590                                    ).await;
591                                }
592                                MessageContent::Thinking(thinking) => {
593                                    let mut sender = sender.lock().await;
594                                    let _ = sender
595                                        .send(Message::Text(
596                                            serde_json::to_string(&WebSocketMessage::Thinking {
597                                                message: thinking.thinking.clone(),
598                                            })
599                                            .unwrap()
600                                            .into(),
601                                        ))
602                                        .await;
603                                }
604                                _ => {}
605                            }
606                        }
607                    }
608                    Ok(AgentEvent::HistoryReplaced(_new_messages)) => {
609                        tracing::info!("History replaced, compacting happened in reply");
610                    }
611                    Ok(AgentEvent::McpNotification(_notification)) => {
612                        tracing::info!("Received MCP notification in web interface");
613                    }
614                    Ok(AgentEvent::ModelChange { model, mode }) => {
615                        tracing::info!("Model changed to {} in {} mode", model, mode);
616                    }
617                    Err(e) => {
618                        error!("Error in message stream: {}", e);
619                        let mut sender = sender.lock().await;
620                        let _ = sender
621                            .send(Message::Text(
622                                serde_json::to_string(&WebSocketMessage::Error {
623                                    message: format!("Error: {}", e),
624                                })
625                                .unwrap()
626                                .into(),
627                            ))
628                            .await;
629                        break;
630                    }
631                }
632            }
633        }
634        Err(e) => {
635            error!("Error calling agent: {}", e);
636            let mut sender = sender.lock().await;
637            let _ = sender
638                .send(Message::Text(
639                    serde_json::to_string(&WebSocketMessage::Error {
640                        message: format!("Error: {}", e),
641                    })
642                    .unwrap()
643                    .into(),
644                ))
645                .await;
646        }
647    }
648
649    let mut sender = sender.lock().await;
650    let _ = sender
651        .send(Message::Text(
652            serde_json::to_string(&WebSocketMessage::Complete {
653                message: "Response complete".to_string(),
654            })
655            .unwrap()
656            .into(),
657        ))
658        .await;
659
660    Ok(())
661}