Skip to main content

adk_server/rest/controllers/
artifacts.rs

1use crate::ServerConfig;
2use adk_artifact::{ListRequest, LoadRequest};
3use axum::{
4    Extension, Json,
5    body::Body,
6    extract::{Path, State},
7    http::{StatusCode, header},
8    response::IntoResponse,
9};
10
11#[derive(Clone)]
12pub struct ArtifactsController {
13    config: ServerConfig,
14}
15
16impl ArtifactsController {
17    pub fn new(config: ServerConfig) -> Self {
18        Self { config }
19    }
20}
21
22fn authorize_user_id(
23    request_context: &Option<adk_core::RequestContext>,
24    user_id: &str,
25) -> Result<String, StatusCode> {
26    match request_context {
27        Some(context) if context.user_id != user_id => Err(StatusCode::FORBIDDEN),
28        Some(context) => Ok(context.user_id.clone()),
29        None => Ok(user_id.to_string()),
30    }
31}
32
33pub async fn list_artifacts(
34    State(controller): State<ArtifactsController>,
35    Extension(request_context): Extension<Option<adk_core::RequestContext>>,
36    Path((app_name, user_id, session_id)): Path<(String, String, String)>,
37) -> Result<Json<Vec<String>>, StatusCode> {
38    let user_id = authorize_user_id(&request_context, &user_id)?;
39
40    if let Some(service) = &controller.config.artifact_service {
41        let resp =
42            service.list(ListRequest { app_name, user_id, session_id }).await.map_err(|e| {
43                tracing::error!(error = %e, "artifact list failed");
44                StatusCode::INTERNAL_SERVER_ERROR
45            })?;
46        Ok(Json(resp.file_names))
47    } else {
48        Ok(Json(vec![]))
49    }
50}
51
52pub async fn get_artifact(
53    State(controller): State<ArtifactsController>,
54    Extension(request_context): Extension<Option<adk_core::RequestContext>>,
55    Path((app_name, user_id, session_id, artifact_name)): Path<(String, String, String, String)>,
56) -> Result<impl IntoResponse, StatusCode> {
57    let user_id = authorize_user_id(&request_context, &user_id)?;
58
59    if let Some(service) = &controller.config.artifact_service {
60        let resp = service
61            .load(LoadRequest {
62                app_name,
63                user_id,
64                session_id,
65                file_name: artifact_name.clone(),
66                version: None,
67            })
68            .await
69            .map_err(|e| {
70                tracing::error!(error = %e, "artifact get failed");
71                StatusCode::NOT_FOUND
72            })?;
73
74        let mime = mime_guess::from_path(&artifact_name).first_or_octet_stream();
75        let mime_header = header::HeaderValue::from_str(mime.as_ref())
76            .unwrap_or_else(|_| header::HeaderValue::from_static("application/octet-stream"));
77
78        match resp.part {
79            adk_core::Part::InlineData { data, .. } => {
80                Ok(([(header::CONTENT_TYPE, mime_header)], Body::from(data)))
81            }
82            adk_core::Part::Text { text } => {
83                Ok(([(header::CONTENT_TYPE, mime_header)], Body::from(text)))
84            }
85            _ => Err(StatusCode::INTERNAL_SERVER_ERROR),
86        }
87    } else {
88        Err(StatusCode::NOT_FOUND)
89    }
90}