adk_server/rest/controllers/
artifacts.rs1use crate::ServerConfig;
2use adk_artifact::{ListRequest, LoadRequest};
3use axum::{
4 Extension, Json,
5 body::Body,
6 extract::{Path, State},
7 http::{StatusCode, header},
8 response::IntoResponse,
9};
10
11#[derive(Clone)]
12pub struct ArtifactsController {
13 config: ServerConfig,
14}
15
16impl ArtifactsController {
17 pub fn new(config: ServerConfig) -> Self {
18 Self { config }
19 }
20}
21
22fn authorize_user_id(
23 request_context: &Option<adk_core::RequestContext>,
24 user_id: &str,
25) -> Result<String, StatusCode> {
26 match request_context {
27 Some(context) if context.user_id != user_id => Err(StatusCode::FORBIDDEN),
28 Some(context) => Ok(context.user_id.clone()),
29 None => Ok(user_id.to_string()),
30 }
31}
32
33pub async fn list_artifacts(
34 State(controller): State<ArtifactsController>,
35 Extension(request_context): Extension<Option<adk_core::RequestContext>>,
36 Path((app_name, user_id, session_id)): Path<(String, String, String)>,
37) -> Result<Json<Vec<String>>, StatusCode> {
38 let user_id = authorize_user_id(&request_context, &user_id)?;
39
40 if let Some(service) = &controller.config.artifact_service {
41 let resp = service
42 .list(ListRequest { app_name, user_id, session_id })
43 .await
44 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
45 Ok(Json(resp.file_names))
46 } else {
47 Ok(Json(vec![]))
48 }
49}
50
51pub async fn get_artifact(
52 State(controller): State<ArtifactsController>,
53 Extension(request_context): Extension<Option<adk_core::RequestContext>>,
54 Path((app_name, user_id, session_id, artifact_name)): Path<(String, String, String, String)>,
55) -> Result<impl IntoResponse, StatusCode> {
56 let user_id = authorize_user_id(&request_context, &user_id)?;
57
58 if let Some(service) = &controller.config.artifact_service {
59 let resp = service
60 .load(LoadRequest {
61 app_name,
62 user_id,
63 session_id,
64 file_name: artifact_name.clone(),
65 version: None,
66 })
67 .await
68 .map_err(|_| StatusCode::NOT_FOUND)?;
69
70 let mime = mime_guess::from_path(&artifact_name).first_or_octet_stream();
71 let mime_header = header::HeaderValue::from_str(mime.as_ref())
72 .unwrap_or_else(|_| header::HeaderValue::from_static("application/octet-stream"));
73
74 match resp.part {
75 adk_core::Part::InlineData { data, .. } => {
76 Ok(([(header::CONTENT_TYPE, mime_header)], Body::from(data)))
77 }
78 adk_core::Part::Text { text } => {
79 Ok(([(header::CONTENT_TYPE, mime_header)], Body::from(text)))
80 }
81 _ => Err(StatusCode::INTERNAL_SERVER_ERROR),
82 }
83 } else {
84 Err(StatusCode::NOT_FOUND)
85 }
86}