//! aster_server/routes/agent.rs — HTTP route handlers for the `/agent/*` endpoints.

1use crate::routes::errors::ErrorResponse;
2use crate::routes::recipe_utils::{
3    apply_recipe_to_agent, build_recipe_with_parameter_values, load_recipe_by_id, validate_recipe,
4};
5use crate::state::AppState;
6use aster::config::PermissionManager;
7use axum::response::IntoResponse;
8use axum::{
9    extract::{Query, State},
10    http::StatusCode,
11    routing::{get, post},
12    Json, Router,
13};
14
15use aster::agents::ExtensionConfig;
16use aster::config::{AsterMode, Config};
17use aster::model::ModelConfig;
18use aster::prompt_template::render_global_file;
19use aster::providers::create;
20use aster::recipe::Recipe;
21use aster::recipe_deeplink;
22use aster::session::session_manager::SessionType;
23use aster::session::{Session, SessionManager};
24use aster::{
25    agents::{extension::ToolInfo, extension_manager::get_parameter_names},
26    config::permission::PermissionLevel,
27};
28use base64::Engine;
29use rmcp::model::{CallToolRequestParam, Content};
30use serde::{Deserialize, Serialize};
31use serde_json::Value;
32use std::collections::HashMap;
33use std::path::PathBuf;
34use std::sync::atomic::Ordering;
35use std::sync::Arc;
36use tokio_util::sync::CancellationToken;
37use tracing::{error, warn};
38
// Request body for `POST /agent/update_from_session`.
#[derive(Deserialize, utoipa::ToSchema)]
pub struct UpdateFromSessionRequest {
    // Session whose agent is looked up and whose stored recipe (if any) is applied.
    session_id: String,
}
43
// Request body for `POST /agent/update_provider`.
#[derive(Deserialize, utoipa::ToSchema)]
pub struct UpdateProviderRequest {
    // Name of the provider to create for the agent.
    provider: String,
    // Model name; when absent the handler falls back to the globally configured model.
    model: Option<String>,
    // Session whose agent receives the new provider.
    session_id: String,
}
50
// Query-string parameters for `GET /agent/tools`.
#[derive(Deserialize, utoipa::ToSchema)]
pub struct GetToolsQuery {
    // Optional extension name passed to the agent's `list_tools` as a filter.
    extension_name: Option<String>,
    // Session whose agent is queried for tools.
    session_id: String,
}
56
// Request body for `POST /agent/start`.
//
// The three recipe fields are all optional (`#[serde(default)]`). `start_agent` consults them
// in the order `recipe_deeplink`, `recipe_id`, `recipe` and uses the first one that is present.
#[derive(Deserialize, utoipa::ToSchema)]
pub struct StartAgentRequest {
    // Working directory for the new session (converted to a `PathBuf` by the handler).
    working_dir: String,
    // Inline recipe, used only when neither a deeplink nor an id is supplied.
    #[serde(default)]
    recipe: Option<Recipe>,
    // Identifier of a stored recipe, loaded via `load_recipe_by_id`.
    #[serde(default)]
    recipe_id: Option<String>,
    // Encoded recipe deeplink, decoded via `recipe_deeplink::decode`.
    #[serde(default)]
    recipe_deeplink: Option<String>,
}
67
// Request body for `POST /agent/stop`.
#[derive(Deserialize, utoipa::ToSchema)]
pub struct StopAgentRequest {
    // Session to remove from the agent manager.
    session_id: String,
}
72
// Request body for `POST /agent/resume`.
#[derive(Deserialize, utoipa::ToSchema)]
pub struct ResumeAgentRequest {
    // Session to resume.
    session_id: String,
    // When true, the handler also configures the agent's provider/model and loads the
    // enabled extensions; when false, only the session is fetched and returned.
    load_model_and_extensions: bool,
}
78
// Request body for `POST /agent/add_extension`.
#[derive(Deserialize, utoipa::ToSchema)]
pub struct AddExtensionRequest {
    // Session whose agent receives the extension.
    session_id: String,
    // Configuration of the extension to add.
    config: ExtensionConfig,
}
84
// Request body for `POST /agent/remove_extension`.
#[derive(Deserialize, utoipa::ToSchema)]
pub struct RemoveExtensionRequest {
    // Name of the extension to remove.
    name: String,
    // Session whose agent the extension is removed from.
    session_id: String,
}
90
// Request body for `POST /agent/read_resource`.
#[derive(Deserialize, utoipa::ToSchema)]
pub struct ReadResourceRequest {
    // Session whose agent performs the read.
    session_id: String,
    // Extension that is asked for the resource.
    extension_name: String,
    // URI of the resource to read.
    uri: String,
}
97
// Response body for `POST /agent/read_resource`.
//
// Field names are serialized in camelCase (so `mime_type` becomes `mimeType`), except `meta`,
// which is explicitly renamed to `_meta`. Both optional fields are omitted when `None`.
#[derive(Serialize, utoipa::ToSchema)]
#[serde(rename_all = "camelCase")]
pub struct ReadResourceResponse {
    // URI reported by the resource contents.
    uri: String,
    // MIME type reported by the resource contents, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    mime_type: Option<String>,
    // Resource text; for blob resources this is the base64-decoded blob interpreted as UTF-8.
    text: String,
    // Metadata map attached to the resource contents, if any.
    #[serde(rename = "_meta", skip_serializing_if = "Option::is_none")]
    meta: Option<serde_json::Map<String, Value>>,
}
108
// Request body for `POST /agent/call_tool`.
#[derive(Deserialize, utoipa::ToSchema)]
pub struct CallToolRequest {
    // Session whose agent dispatches the tool call.
    session_id: String,
    // Name of the tool to call.
    name: String,
    // Tool arguments; the handler forwards them only when this is a JSON object and
    // otherwise sends no arguments.
    arguments: Value,
}
115
// Response body for `POST /agent/call_tool`.
//
// No `rename_all` is applied, so fields are serialized with their snake_case names.
#[derive(Serialize, utoipa::ToSchema)]
pub struct CallToolResponse {
    // Content items returned by the tool.
    content: Vec<Content>,
    // Structured content returned by the tool, if any (serialized as `null` when absent).
    structured_content: Option<Value>,
    // Whether the tool reported an error; the handler defaults this to `false`.
    is_error: bool,
    // Tool result metadata converted to JSON; omitted from the output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    _meta: Option<Value>,
}
124
125#[utoipa::path(
126    post,
127    path = "/agent/start",
128    request_body = StartAgentRequest,
129    responses(
130        (status = 200, description = "Agent started successfully", body = Session),
131        (status = 400, description = "Bad request", body = ErrorResponse),
132        (status = 401, description = "Unauthorized - invalid secret key"),
133        (status = 500, description = "Internal server error", body = ErrorResponse)
134    )
135)]
136async fn start_agent(
137    State(state): State<Arc<AppState>>,
138    Json(payload): Json<StartAgentRequest>,
139) -> Result<Json<Session>, ErrorResponse> {
140    aster::posthog::set_session_context("desktop", false);
141
142    let StartAgentRequest {
143        working_dir,
144        recipe,
145        recipe_id,
146        recipe_deeplink,
147    } = payload;
148
149    let original_recipe = if let Some(deeplink) = recipe_deeplink {
150        match recipe_deeplink::decode(&deeplink) {
151            Ok(recipe) => Some(recipe),
152            Err(err) => {
153                error!("Failed to decode recipe deeplink: {}", err);
154                aster::posthog::emit_error("recipe_deeplink_decode_failed", &err.to_string());
155                return Err(ErrorResponse {
156                    message: err.to_string(),
157                    status: StatusCode::BAD_REQUEST,
158                });
159            }
160        }
161    } else if let Some(id) = recipe_id {
162        match load_recipe_by_id(state.as_ref(), &id).await {
163            Ok(recipe) => Some(recipe),
164            Err(err) => return Err(err),
165        }
166    } else {
167        recipe
168    };
169
170    if let Some(ref recipe) = original_recipe {
171        if let Err(err) = validate_recipe(recipe) {
172            return Err(ErrorResponse {
173                message: err.message,
174                status: err.status,
175            });
176        }
177    }
178
179    let counter = state.session_counter.fetch_add(1, Ordering::SeqCst) + 1;
180    let name = format!("New session {}", counter);
181
182    let mut session =
183        SessionManager::create_session(PathBuf::from(&working_dir), name, SessionType::User)
184            .await
185            .map_err(|err| {
186                error!("Failed to create session: {}", err);
187                aster::posthog::emit_error("session_create_failed", &err.to_string());
188                ErrorResponse {
189                    message: format!("Failed to create session: {}", err),
190                    status: StatusCode::BAD_REQUEST,
191                }
192            })?;
193
194    if let Some(recipe) = original_recipe {
195        SessionManager::update_session(&session.id)
196            .recipe(Some(recipe))
197            .apply()
198            .await
199            .map_err(|err| {
200                error!("Failed to update session with recipe: {}", err);
201                ErrorResponse {
202                    message: format!("Failed to update session with recipe: {}", err),
203                    status: StatusCode::INTERNAL_SERVER_ERROR,
204                }
205            })?;
206
207        session = SessionManager::get_session(&session.id, false)
208            .await
209            .map_err(|err| {
210                error!("Failed to get updated session: {}", err);
211                ErrorResponse {
212                    message: format!("Failed to get updated session: {}", err),
213                    status: StatusCode::INTERNAL_SERVER_ERROR,
214                }
215            })?;
216    }
217
218    Ok(Json(session))
219}
220
221#[utoipa::path(
222    post,
223    path = "/agent/resume",
224    request_body = ResumeAgentRequest,
225    responses(
226        (status = 200, description = "Agent started successfully", body = Session),
227        (status = 400, description = "Bad request - invalid working directory"),
228        (status = 401, description = "Unauthorized - invalid secret key"),
229        (status = 500, description = "Internal server error")
230    )
231)]
232async fn resume_agent(
233    State(state): State<Arc<AppState>>,
234    Json(payload): Json<ResumeAgentRequest>,
235) -> Result<Json<Session>, ErrorResponse> {
236    aster::posthog::set_session_context("desktop", true);
237
238    let session = SessionManager::get_session(&payload.session_id, true)
239        .await
240        .map_err(|err| {
241            error!("Failed to resume session {}: {}", payload.session_id, err);
242            aster::posthog::emit_error("session_resume_failed", &err.to_string());
243            ErrorResponse {
244                message: format!("Failed to resume session: {}", err),
245                status: StatusCode::NOT_FOUND,
246            }
247        })?;
248
249    if payload.load_model_and_extensions {
250        let agent = state
251            .get_agent_for_route(payload.session_id.clone())
252            .await
253            .map_err(|code| ErrorResponse {
254                message: "Failed to get agent for route".into(),
255                status: code,
256            })?;
257
258        let config = Config::global();
259
260        let provider_result = async {
261            let provider_name = session
262                .provider_name
263                .clone()
264                .or_else(|| config.get_aster_provider().ok())
265                .ok_or_else(|| ErrorResponse {
266                    message: "Could not configure agent: missing provider".into(),
267                    status: StatusCode::INTERNAL_SERVER_ERROR,
268                })?;
269
270            let model_config = match session.model_config.clone() {
271                Some(saved_config) => saved_config,
272                None => {
273                    let model_name = config.get_aster_model().map_err(|_| ErrorResponse {
274                        message: "Could not configure agent: missing model".into(),
275                        status: StatusCode::INTERNAL_SERVER_ERROR,
276                    })?;
277                    ModelConfig::new(&model_name).map_err(|e| ErrorResponse {
278                        message: format!("Could not configure agent: invalid model {}", e),
279                        status: StatusCode::INTERNAL_SERVER_ERROR,
280                    })?
281                }
282            };
283
284            let provider =
285                create(&provider_name, model_config)
286                    .await
287                    .map_err(|e| ErrorResponse {
288                        message: format!("Could not create provider: {}", e),
289                        status: StatusCode::INTERNAL_SERVER_ERROR,
290                    })?;
291
292            agent
293                .update_provider(provider, &payload.session_id)
294                .await
295                .map_err(|e| ErrorResponse {
296                    message: format!("Could not configure agent: {}", e),
297                    status: StatusCode::INTERNAL_SERVER_ERROR,
298                })
299        };
300
301        let extensions_result = async {
302            let enabled_configs = aster::config::get_enabled_extensions();
303            let agent_clone = agent.clone();
304
305            let extension_futures = enabled_configs
306                .into_iter()
307                .map(|config| {
308                    let config_clone = config.clone();
309                    let agent_ref = agent_clone.clone();
310
311                    async move {
312                        if let Err(e) = agent_ref.add_extension(config_clone.clone()).await {
313                            warn!("Failed to load extension {}: {}", config_clone.name(), e);
314                            aster::posthog::emit_error(
315                                "extension_load_failed",
316                                &format!("{}: {}", config_clone.name(), e),
317                            );
318                        }
319                        Ok::<_, ErrorResponse>(())
320                    }
321                })
322                .collect::<Vec<_>>();
323
324            futures::future::join_all(extension_futures).await;
325            Ok::<(), ErrorResponse>(()) // Fixed type annotation
326        };
327
328        let (provider_result, _) = tokio::join!(provider_result, extensions_result);
329        provider_result?;
330    }
331
332    Ok(Json(session))
333}
334
335#[utoipa::path(
336    post,
337    path = "/agent/update_from_session",
338    request_body = UpdateFromSessionRequest,
339    responses(
340        (status = 200, description = "Update agent from session data successfully"),
341        (status = 401, description = "Unauthorized - invalid secret key"),
342        (status = 424, description = "Agent not initialized"),
343    ),
344)]
345async fn update_from_session(
346    State(state): State<Arc<AppState>>,
347    Json(payload): Json<UpdateFromSessionRequest>,
348) -> Result<StatusCode, ErrorResponse> {
349    let agent = state
350        .get_agent_for_route(payload.session_id.clone())
351        .await
352        .map_err(|status| ErrorResponse {
353            message: format!("Failed to get agent: {}", status),
354            status,
355        })?;
356    let session = SessionManager::get_session(&payload.session_id, false)
357        .await
358        .map_err(|err| ErrorResponse {
359            message: format!("Failed to get session: {}", err),
360            status: StatusCode::INTERNAL_SERVER_ERROR,
361        })?;
362    let context: HashMap<&str, Value> = HashMap::new();
363    let desktop_prompt =
364        render_global_file("desktop_prompt.md", &context).expect("Prompt should render");
365    let mut update_prompt = desktop_prompt;
366    if let Some(recipe) = session.recipe {
367        match build_recipe_with_parameter_values(
368            &recipe,
369            session.user_recipe_values.unwrap_or_default(),
370        )
371        .await
372        {
373            Ok(Some(recipe)) => {
374                if let Some(prompt) = apply_recipe_to_agent(&agent, &recipe, true).await {
375                    update_prompt = prompt;
376                }
377            }
378            Ok(None) => {
379                // Recipe has missing parameters - use default prompt
380            }
381            Err(e) => {
382                return Err(ErrorResponse {
383                    message: e.to_string(),
384                    status: StatusCode::INTERNAL_SERVER_ERROR,
385                });
386            }
387        }
388    }
389    agent.extend_system_prompt(update_prompt).await;
390
391    Ok(StatusCode::OK)
392}
393
394#[utoipa::path(
395    get,
396    path = "/agent/tools",
397    params(
398        ("extension_name" = Option<String>, Query, description = "Optional extension name to filter tools"),
399        ("session_id" = String, Query, description = "Required session ID to scope tools to a specific session")
400    ),
401    responses(
402        (status = 200, description = "Tools retrieved successfully", body = Vec<ToolInfo>),
403        (status = 401, description = "Unauthorized - invalid secret key"),
404        (status = 424, description = "Agent not initialized"),
405        (status = 500, description = "Internal server error")
406    )
407)]
408async fn get_tools(
409    State(state): State<Arc<AppState>>,
410    Query(query): Query<GetToolsQuery>,
411) -> Result<Json<Vec<ToolInfo>>, StatusCode> {
412    let config = Config::global();
413    let aster_mode = config.get_aster_mode().unwrap_or(AsterMode::Auto);
414    let agent = state.get_agent_for_route(query.session_id).await?;
415    let permission_manager = PermissionManager::default();
416
417    let mut tools: Vec<ToolInfo> = agent
418        .list_tools(query.extension_name)
419        .await
420        .into_iter()
421        .map(|tool| {
422            let permission = permission_manager
423                .get_user_permission(&tool.name)
424                .or_else(|| {
425                    if aster_mode == AsterMode::SmartApprove {
426                        permission_manager.get_smart_approve_permission(&tool.name)
427                    } else if aster_mode == AsterMode::Approve {
428                        Some(PermissionLevel::AskBefore)
429                    } else {
430                        None
431                    }
432                });
433
434            ToolInfo::new(
435                &tool.name,
436                tool.description
437                    .as_ref()
438                    .map(|d| d.as_ref())
439                    .unwrap_or_default(),
440                get_parameter_names(&tool),
441                permission,
442            )
443        })
444        .collect::<Vec<ToolInfo>>();
445    tools.sort_by(|a, b| a.name.cmp(&b.name));
446
447    Ok(Json(tools))
448}
449
450#[utoipa::path(
451    post,
452    path = "/agent/update_provider",
453    request_body = UpdateProviderRequest,
454    responses(
455        (status = 200, description = "Provider updated successfully"),
456        (status = 400, description = "Bad request - missing or invalid parameters"),
457        (status = 401, description = "Unauthorized - invalid secret key"),
458        (status = 424, description = "Agent not initialized"),
459        (status = 500, description = "Internal server error")
460    )
461)]
462async fn update_agent_provider(
463    State(state): State<Arc<AppState>>,
464    Json(payload): Json<UpdateProviderRequest>,
465) -> Result<(), impl IntoResponse> {
466    let agent = state
467        .get_agent_for_route(payload.session_id.clone())
468        .await
469        .map_err(|e| (e, "No agent for session id".to_owned()))?;
470
471    let config = Config::global();
472    let model = match payload.model.or_else(|| config.get_aster_model().ok()) {
473        Some(m) => m,
474        None => {
475            return Err((StatusCode::BAD_REQUEST, "No model specified".to_owned()));
476        }
477    };
478
479    let model_config = ModelConfig::new(&model).map_err(|e| {
480        (
481            StatusCode::BAD_REQUEST,
482            format!("Invalid model config: {}", e),
483        )
484    })?;
485
486    let new_provider = create(&payload.provider, model_config).await.map_err(|e| {
487        (
488            StatusCode::BAD_REQUEST,
489            format!("Failed to create {} provider: {}", &payload.provider, e),
490        )
491    })?;
492
493    agent
494        .update_provider(new_provider, &payload.session_id)
495        .await
496        .map_err(|e| {
497            (
498                StatusCode::INTERNAL_SERVER_ERROR,
499                format!("Failed to update provider: {}", e),
500            )
501        })?;
502
503    Ok(())
504}
505
506#[utoipa::path(
507    post,
508    path = "/agent/add_extension",
509    request_body = AddExtensionRequest,
510    responses(
511        (status = 200, description = "Extension added", body = String),
512        (status = 401, description = "Unauthorized - invalid secret key"),
513        (status = 424, description = "Agent not initialized"),
514        (status = 500, description = "Internal server error")
515    )
516)]
517async fn agent_add_extension(
518    State(state): State<Arc<AppState>>,
519    Json(request): Json<AddExtensionRequest>,
520) -> Result<StatusCode, ErrorResponse> {
521    let extension_name = request.config.name();
522    let agent = state.get_agent(request.session_id).await?;
523    agent.add_extension(request.config).await.map_err(|e| {
524        aster::posthog::emit_error(
525            "extension_add_failed",
526            &format!("{}: {}", extension_name, e),
527        );
528        ErrorResponse::internal(format!("Failed to add extension: {}", e))
529    })?;
530    Ok(StatusCode::OK)
531}
532
533#[utoipa::path(
534    post,
535    path = "/agent/remove_extension",
536    request_body = RemoveExtensionRequest,
537    responses(
538        (status = 200, description = "Extension removed", body = String),
539        (status = 401, description = "Unauthorized - invalid secret key"),
540        (status = 424, description = "Agent not initialized"),
541        (status = 500, description = "Internal server error")
542    )
543)]
544async fn agent_remove_extension(
545    State(state): State<Arc<AppState>>,
546    Json(request): Json<RemoveExtensionRequest>,
547) -> Result<StatusCode, ErrorResponse> {
548    let agent = state.get_agent(request.session_id).await?;
549    agent.remove_extension(&request.name).await?;
550    Ok(StatusCode::OK)
551}
552
553#[utoipa::path(
554    post,
555    path = "/agent/stop",
556    request_body = StopAgentRequest,
557    responses(
558        (status = 200, description = "Agent stopped successfully", body = String),
559        (status = 401, description = "Unauthorized - invalid secret key"),
560        (status = 404, description = "Session not found"),
561        (status = 500, description = "Internal server error")
562    )
563)]
564async fn stop_agent(
565    State(state): State<Arc<AppState>>,
566    Json(payload): Json<StopAgentRequest>,
567) -> Result<StatusCode, ErrorResponse> {
568    let session_id = payload.session_id;
569    state
570        .agent_manager
571        .remove_session(&session_id)
572        .await
573        .map_err(|e| ErrorResponse {
574            message: format!("Failed to stop agent for session {}: {}", session_id, e),
575            status: StatusCode::NOT_FOUND,
576        })?;
577
578    Ok(StatusCode::OK)
579}
580
581#[utoipa::path(
582    post,
583    path = "/agent/read_resource",
584    request_body = ReadResourceRequest,
585    responses(
586        (status = 200, description = "Resource read successfully", body = ReadResourceResponse),
587        (status = 401, description = "Unauthorized - invalid secret key"),
588        (status = 424, description = "Agent not initialized"),
589        (status = 404, description = "Resource not found"),
590        (status = 500, description = "Internal server error")
591    )
592)]
593async fn read_resource(
594    State(state): State<Arc<AppState>>,
595    Json(payload): Json<ReadResourceRequest>,
596) -> Result<Json<ReadResourceResponse>, StatusCode> {
597    use rmcp::model::ResourceContents;
598
599    let agent = state
600        .get_agent_for_route(payload.session_id.clone())
601        .await?;
602
603    let read_result = agent
604        .extension_manager
605        .read_resource(
606            &payload.uri,
607            &payload.extension_name,
608            CancellationToken::default(),
609        )
610        .await
611        .map_err(|_e| StatusCode::INTERNAL_SERVER_ERROR)?;
612
613    let content = read_result
614        .contents
615        .into_iter()
616        .next()
617        .ok_or(StatusCode::NOT_FOUND)?;
618
619    let (uri, mime_type, text, meta) = match content {
620        ResourceContents::TextResourceContents {
621            uri,
622            mime_type,
623            text,
624            meta,
625        } => (uri, mime_type, text, meta),
626        ResourceContents::BlobResourceContents {
627            uri,
628            mime_type,
629            blob,
630            meta,
631        } => {
632            let decoded = match base64::engine::general_purpose::STANDARD.decode(&blob) {
633                Ok(bytes) => {
634                    String::from_utf8(bytes).map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?
635                }
636                Err(_) => return Err(StatusCode::INTERNAL_SERVER_ERROR),
637            };
638            (uri, mime_type, decoded, meta)
639        }
640    };
641
642    let meta_map = meta.map(|m| m.0);
643
644    Ok(Json(ReadResourceResponse {
645        uri,
646        mime_type,
647        text,
648        meta: meta_map,
649    }))
650}
651
652#[utoipa::path(
653    post,
654    path = "/agent/call_tool",
655    request_body = CallToolRequest,
656    responses(
657        (status = 200, description = "Resource read successfully", body = CallToolResponse),
658        (status = 401, description = "Unauthorized - invalid secret key"),
659        (status = 424, description = "Agent not initialized"),
660        (status = 404, description = "Resource not found"),
661        (status = 500, description = "Internal server error")
662    )
663)]
664async fn call_tool(
665    State(state): State<Arc<AppState>>,
666    Json(payload): Json<CallToolRequest>,
667) -> Result<Json<CallToolResponse>, StatusCode> {
668    let agent = state
669        .get_agent_for_route(payload.session_id.clone())
670        .await?;
671
672    let arguments = match payload.arguments {
673        Value::Object(map) => Some(map),
674        _ => None,
675    };
676
677    let tool_call = CallToolRequestParam {
678        name: payload.name.into(),
679        arguments,
680    };
681
682    let tool_result = agent
683        .extension_manager
684        .dispatch_tool_call(tool_call, CancellationToken::default())
685        .await
686        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
687
688    let result = tool_result
689        .result
690        .await
691        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
692
693    Ok(Json(CallToolResponse {
694        content: result.content,
695        structured_content: result.structured_content,
696        is_error: result.is_error.unwrap_or(false),
697        _meta: result.meta.and_then(|m| serde_json::to_value(m).ok()),
698    }))
699}
700
701pub fn routes(state: Arc<AppState>) -> Router {
702    Router::new()
703        .route("/agent/start", post(start_agent))
704        .route("/agent/resume", post(resume_agent))
705        .route("/agent/tools", get(get_tools))
706        .route("/agent/read_resource", post(read_resource))
707        .route("/agent/call_tool", post(call_tool))
708        .route("/agent/update_provider", post(update_agent_provider))
709        .route("/agent/update_from_session", post(update_from_session))
710        .route("/agent/add_extension", post(agent_add_extension))
711        .route("/agent/remove_extension", post(agent_remove_extension))
712        .route("/agent/stop", post(stop_agent))
713        .with_state(state)
714}