adk_server/rest/controllers/
artifacts.rs1use crate::ServerConfig;
2use adk_artifact::{ListRequest, LoadRequest};
3use axum::{
4 Extension, Json,
5 body::Body,
6 extract::{Path, State},
7 http::{StatusCode, header},
8 response::IntoResponse,
9};
10
11#[derive(Clone)]
12pub struct ArtifactsController {
13 config: ServerConfig,
14}
15
16impl ArtifactsController {
17 pub fn new(config: ServerConfig) -> Self {
18 Self { config }
19 }
20}
21
22fn authorize_user_id(
23 request_context: &Option<adk_core::RequestContext>,
24 user_id: &str,
25) -> Result<String, StatusCode> {
26 match request_context {
27 Some(context) if context.user_id != user_id => Err(StatusCode::FORBIDDEN),
28 Some(context) => Ok(context.user_id.clone()),
29 None => Ok(user_id.to_string()),
30 }
31}
32
33pub async fn list_artifacts(
34 State(controller): State<ArtifactsController>,
35 Extension(request_context): Extension<Option<adk_core::RequestContext>>,
36 Path((app_name, user_id, session_id)): Path<(String, String, String)>,
37) -> Result<Json<Vec<String>>, StatusCode> {
38 let user_id = authorize_user_id(&request_context, &user_id)?;
39
40 if let Some(service) = &controller.config.artifact_service {
41 let resp =
42 service.list(ListRequest { app_name, user_id, session_id }).await.map_err(|e| {
43 tracing::error!(error = %e, "artifact list failed");
44 StatusCode::INTERNAL_SERVER_ERROR
45 })?;
46 Ok(Json(resp.file_names))
47 } else {
48 Ok(Json(vec![]))
49 }
50}
51
52pub async fn get_artifact(
53 State(controller): State<ArtifactsController>,
54 Extension(request_context): Extension<Option<adk_core::RequestContext>>,
55 Path((app_name, user_id, session_id, artifact_name)): Path<(String, String, String, String)>,
56) -> Result<impl IntoResponse, StatusCode> {
57 let user_id = authorize_user_id(&request_context, &user_id)?;
58
59 if let Some(service) = &controller.config.artifact_service {
60 let resp = service
61 .load(LoadRequest {
62 app_name,
63 user_id,
64 session_id,
65 file_name: artifact_name.clone(),
66 version: None,
67 })
68 .await
69 .map_err(|e| {
70 tracing::error!(error = %e, "artifact get failed");
71 StatusCode::NOT_FOUND
72 })?;
73
74 let mime = mime_guess::from_path(&artifact_name).first_or_octet_stream();
75 let mime_header = header::HeaderValue::from_str(mime.as_ref())
76 .unwrap_or_else(|_| header::HeaderValue::from_static("application/octet-stream"));
77
78 match resp.part {
79 adk_core::Part::InlineData { data, .. } => {
80 Ok(([(header::CONTENT_TYPE, mime_header)], Body::from(data)))
81 }
82 adk_core::Part::Text { text } => {
83 Ok(([(header::CONTENT_TYPE, mime_header)], Body::from(text)))
84 }
85 _ => Err(StatusCode::INTERNAL_SERVER_ERROR),
86 }
87 } else {
88 Err(StatusCode::NOT_FOUND)
89 }
90}