Skip to main content

adk_server/rest/controllers/
artifacts.rs

1use crate::ServerConfig;
2use adk_artifact::{ListRequest, LoadRequest};
3use axum::{
4    Extension, Json,
5    body::Body,
6    extract::{Path, State},
7    http::{StatusCode, header},
8    response::IntoResponse,
9};
10
11#[derive(Clone)]
12pub struct ArtifactsController {
13    config: ServerConfig,
14}
15
16impl ArtifactsController {
17    pub fn new(config: ServerConfig) -> Self {
18        Self { config }
19    }
20}
21
22fn authorize_user_id(
23    request_context: &Option<adk_core::RequestContext>,
24    user_id: &str,
25) -> Result<String, StatusCode> {
26    match request_context {
27        Some(context) if context.user_id != user_id => Err(StatusCode::FORBIDDEN),
28        Some(context) => Ok(context.user_id.clone()),
29        None => Ok(user_id.to_string()),
30    }
31}
32
33pub async fn list_artifacts(
34    State(controller): State<ArtifactsController>,
35    Extension(request_context): Extension<Option<adk_core::RequestContext>>,
36    Path((app_name, user_id, session_id)): Path<(String, String, String)>,
37) -> Result<Json<Vec<String>>, StatusCode> {
38    let user_id = authorize_user_id(&request_context, &user_id)?;
39
40    if let Some(service) = &controller.config.artifact_service {
41        let resp = service
42            .list(ListRequest { app_name, user_id, session_id })
43            .await
44            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
45        Ok(Json(resp.file_names))
46    } else {
47        Ok(Json(vec![]))
48    }
49}
50
51pub async fn get_artifact(
52    State(controller): State<ArtifactsController>,
53    Extension(request_context): Extension<Option<adk_core::RequestContext>>,
54    Path((app_name, user_id, session_id, artifact_name)): Path<(String, String, String, String)>,
55) -> Result<impl IntoResponse, StatusCode> {
56    let user_id = authorize_user_id(&request_context, &user_id)?;
57
58    if let Some(service) = &controller.config.artifact_service {
59        let resp = service
60            .load(LoadRequest {
61                app_name,
62                user_id,
63                session_id,
64                file_name: artifact_name.clone(),
65                version: None,
66            })
67            .await
68            .map_err(|_| StatusCode::NOT_FOUND)?;
69
70        let mime = mime_guess::from_path(&artifact_name).first_or_octet_stream();
71        let mime_header = header::HeaderValue::from_str(mime.as_ref())
72            .unwrap_or_else(|_| header::HeaderValue::from_static("application/octet-stream"));
73
74        match resp.part {
75            adk_core::Part::InlineData { data, .. } => {
76                Ok(([(header::CONTENT_TYPE, mime_header)], Body::from(data)))
77            }
78            adk_core::Part::Text { text } => {
79                Ok(([(header::CONTENT_TYPE, mime_header)], Body::from(text)))
80            }
81            _ => Err(StatusCode::INTERNAL_SERVER_ERROR),
82        }
83    } else {
84        Err(StatusCode::NOT_FOUND)
85    }
86}