1use axum::{extract::State, http::StatusCode, response::Json};
7use serde::{Deserialize, Serialize};
8
9use super::state::BancoState;
10use super::types::ErrorResponse;
11
12pub async fn pull_model_handler(
14 State(_state): State<BancoState>,
15 Json(request): Json<PullRequest>,
16) -> Result<Json<PullResult>, (StatusCode, Json<ErrorResponse>)> {
17 pull_model(&request.model_ref)
18}
19
20pub async fn list_registry_handler(State(_state): State<BancoState>) -> Json<RegistryListResponse> {
22 Json(RegistryListResponse { models: list_cached_models() })
23}
24
25pub async fn remove_cached_model_handler(
27 State(_state): State<BancoState>,
28 axum::extract::Path(name): axum::extract::Path<String>,
29) -> Result<StatusCode, (StatusCode, Json<ErrorResponse>)> {
30 remove_cached_model(&name)
31}
32
/// Resolve `model_ref` through the `pacha` registry, pulling it if needed.
///
/// If the fetcher already reports the model as cached, metadata is read from
/// the cache listing and no pull is attempted. Otherwise the model is fetched
/// with `pull_quiet`. In that case `status` is `"cached"` whenever the fetcher
/// reports a cache hit and `"pulled"` otherwise, so `status` never contradicts
/// `cache_hit`.
///
/// # Errors
///
/// * `500 registry_error` when the fetcher cannot be constructed.
/// * `404 not_found` when `pull_quiet` fails.
///
/// NOTE(review): every `pull_quiet` failure is reported as 404, including
/// failures that may not mean "not found" (e.g. I/O or network errors).
/// Distinguishing them requires inspecting pacha's error type, so confirm and
/// refine.
#[cfg(feature = "native")]
fn pull_model(model_ref: &str) -> Result<Json<PullResult>, (StatusCode, Json<ErrorResponse>)> {
    let mut fetcher = pacha::fetcher::ModelFetcher::new().map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(ErrorResponse::new(format!("Registry init failed: {e}"), "registry_error", 500)),
        )
    })?;

    if fetcher.is_cached(model_ref) {
        // Fast path: answer from the cache listing without pulling. Metadata stays
        // `None` if no listed entry's name equals `model_ref` exactly.
        let cached = fetcher.list();
        let info = cached.iter().find(|m| m.name == model_ref);
        return Ok(Json(PullResult {
            model_ref: model_ref.to_string(),
            status: "cached".to_string(),
            path: info.map(|m| m.path.display().to_string()),
            size_bytes: info.map(|m| m.size_bytes),
            cache_hit: true,
            format: info.map(|m| format!("{:?}", m.format).to_lowercase()),
        }));
    }

    match fetcher.pull_quiet(model_ref) {
        Ok(result) => {
            // `pull_quiet` reports its own cache hits; keep `status` in agreement.
            let status = if result.cache_hit { "cached" } else { "pulled" };
            Ok(Json(PullResult {
                model_ref: model_ref.to_string(),
                status: status.to_string(),
                path: Some(result.path.display().to_string()),
                size_bytes: Some(result.size_bytes),
                cache_hit: result.cache_hit,
                format: Some(format!("{:?}", result.format).to_lowercase()),
            }))
        }
        Err(e) => Err((
            StatusCode::NOT_FOUND,
            Json(ErrorResponse::new(format!("Model not found: {e}"), "not_found", 404)),
        )),
    }
}
76
/// Snapshot the `pacha` cache as API-facing [`CachedModelInfo`] records.
///
/// This is best-effort by design. If the fetcher cannot be constructed, the
/// error is discarded and an empty list is returned instead of failing the
/// request.
#[cfg(feature = "native")]
fn list_cached_models() -> Vec<CachedModelInfo> {
    pacha::fetcher::ModelFetcher::new()
        .map(|fetcher| {
            fetcher
                .list()
                .into_iter()
                .map(|entry| CachedModelInfo {
                    name: entry.name.clone(),
                    version: entry.version.clone(),
                    path: entry.path.display().to_string(),
                    size_bytes: entry.size_bytes,
                    // Lowercased `Debug` rendering of the format value.
                    format: format!("{:?}", entry.format).to_lowercase(),
                })
                .collect::<Vec<_>>()
        })
        .unwrap_or_default()
}
96
/// Delete `name` from the `pacha` cache.
///
/// Returns `204 No Content` when the fetcher reports that something was removed.
///
/// # Errors
///
/// * `500 registry_error` when fetcher construction or the removal itself fails.
/// * `404 not_found` when the fetcher reports nothing to remove under `name`.
#[cfg(feature = "native")]
fn remove_cached_model(name: &str) -> Result<StatusCode, (StatusCode, Json<ErrorResponse>)> {
    let mut fetcher = match pacha::fetcher::ModelFetcher::new() {
        Ok(fetcher) => fetcher,
        Err(e) => {
            return Err((
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(ErrorResponse::new(format!("Registry: {e}"), "registry_error", 500)),
            ));
        }
    };

    let removed = fetcher.remove(name).map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(ErrorResponse::new(format!("Remove failed: {e}"), "registry_error", 500)),
        )
    })?;

    if removed {
        Ok(StatusCode::NO_CONTENT)
    } else {
        Err((
            StatusCode::NOT_FOUND,
            Json(ErrorResponse::new(format!("Model {name} not in cache"), "not_found", 404)),
        ))
    }
}
118
119#[cfg(not(feature = "native"))]
124fn pull_model(model_ref: &str) -> Result<Json<PullResult>, (StatusCode, Json<ErrorResponse>)> {
125 Ok(Json(PullResult {
126 model_ref: model_ref.to_string(),
127 status: "dry_run".to_string(),
128 path: None,
129 size_bytes: None,
130 cache_hit: false,
131 format: None,
132 }))
133}
134
135#[cfg(not(feature = "native"))]
136fn list_cached_models() -> Vec<CachedModelInfo> {
137 Vec::new()
138}
139
140#[cfg(not(feature = "native"))]
141fn remove_cached_model(_name: &str) -> Result<StatusCode, (StatusCode, Json<ErrorResponse>)> {
142 Ok(StatusCode::NO_CONTENT)
143}
144
145#[derive(Debug, Deserialize)]
151pub struct PullRequest {
152 pub model_ref: String,
154}
155
156#[derive(Debug, Clone, Serialize)]
158pub struct PullResult {
159 pub model_ref: String,
160 pub status: String,
161 #[serde(skip_serializing_if = "Option::is_none")]
162 pub path: Option<String>,
163 #[serde(skip_serializing_if = "Option::is_none")]
164 pub size_bytes: Option<u64>,
165 pub cache_hit: bool,
166 #[serde(skip_serializing_if = "Option::is_none")]
167 pub format: Option<String>,
168}
169
170#[derive(Debug, Clone, Serialize)]
172pub struct CachedModelInfo {
173 pub name: String,
174 pub version: String,
175 pub path: String,
176 pub size_bytes: u64,
177 pub format: String,
178}
179
180#[derive(Debug, Serialize)]
182pub struct RegistryListResponse {
183 pub models: Vec<CachedModelInfo>,
184}