Skip to main content

batuta/serve/banco/
handlers_registry.rs

//! Model registry handlers — pull, list, and remove cached models via pacha.
//!
//! With the `native` feature (which includes pacha): real registry operations.
//! Without it: dry-run responses for API testing.
5
6use axum::{extract::State, http::StatusCode, response::Json};
7use serde::{Deserialize, Serialize};
8
9use super::state::BancoState;
10use super::types::ErrorResponse;
11
12/// POST /api/v1/models/pull — pull a model from the registry.
13pub async fn pull_model_handler(
14    State(_state): State<BancoState>,
15    Json(request): Json<PullRequest>,
16) -> Result<Json<PullResult>, (StatusCode, Json<ErrorResponse>)> {
17    pull_model(&request.model_ref)
18}
19
20/// GET /api/v1/models/registry — list cached models.
21pub async fn list_registry_handler(State(_state): State<BancoState>) -> Json<RegistryListResponse> {
22    Json(RegistryListResponse { models: list_cached_models() })
23}
24
25/// DELETE /api/v1/models/registry/:name — remove a model from cache.
26pub async fn remove_cached_model_handler(
27    State(_state): State<BancoState>,
28    axum::extract::Path(name): axum::extract::Path<String>,
29) -> Result<StatusCode, (StatusCode, Json<ErrorResponse>)> {
30    remove_cached_model(&name)
31}
32
33// ============================================================================
34// pacha-powered registry (native feature)
35// ============================================================================
36
/// Pull `model_ref` through pacha, skipping the download when it is already cached.
///
/// On a cache hit, reports `status: "cached"` with whatever metadata the cache
/// listing holds for an entry named exactly `model_ref`. The optional fields
/// stay `None` if no entry matches by name. Otherwise the model is fetched and
/// reported with `status: "pulled"`.
///
/// # Errors
///
/// * `500` / `registry_error` if the pacha fetcher cannot be constructed.
/// * `404` / `not_found` for any failure returned by the pull itself.
#[cfg(feature = "native")]
fn pull_model(model_ref: &str) -> Result<Json<PullResult>, (StatusCode, Json<ErrorResponse>)> {
    let mut fetcher = pacha::fetcher::ModelFetcher::new().map_err(|err| {
        let body =
            ErrorResponse::new(format!("Registry init failed: {err}"), "registry_error", 500);
        (StatusCode::INTERNAL_SERVER_ERROR, Json(body))
    })?;

    if fetcher.is_cached(model_ref) {
        // NOTE(review): metadata is looked up by exact entry name. If
        // `is_cached` also resolves refs such as "name:version", these fields
        // may come back empty, so confirm against pacha's naming scheme.
        let entries = fetcher.list();
        let hit = entries.iter().find(|entry| entry.name == model_ref);
        let cached = PullResult {
            model_ref: model_ref.to_owned(),
            status: String::from("cached"),
            path: hit.map(|entry| entry.path.display().to_string()),
            size_bytes: hit.map(|entry| entry.size_bytes),
            cache_hit: true,
            format: hit.map(|entry| format!("{:?}", entry.format).to_lowercase()),
        };
        return Ok(Json(cached));
    }

    fetcher
        .pull_quiet(model_ref)
        .map(|pulled| {
            Json(PullResult {
                model_ref: model_ref.to_owned(),
                status: String::from("pulled"),
                path: Some(pulled.path.display().to_string()),
                size_bytes: Some(pulled.size_bytes),
                cache_hit: pulled.cache_hit,
                format: Some(format!("{:?}", pulled.format).to_lowercase()),
            })
        })
        .map_err(|err| {
            let body = ErrorResponse::new(format!("Model not found: {err}"), "not_found", 404);
            (StatusCode::NOT_FOUND, Json(body))
        })
}
76
/// List every model in the local pacha cache.
///
/// Best-effort: if the fetcher cannot be constructed, the error is discarded
/// and an empty list is returned. The caller has no error channel for it.
#[cfg(feature = "native")]
fn list_cached_models() -> Vec<CachedModelInfo> {
    match pacha::fetcher::ModelFetcher::new() {
        Ok(registry) => registry
            .list()
            .into_iter()
            .map(|entry| CachedModelInfo {
                name: entry.name.clone(),
                version: entry.version.clone(),
                path: entry.path.display().to_string(),
                size_bytes: entry.size_bytes,
                format: format!("{:?}", entry.format).to_lowercase(),
            })
            .collect(),
        Err(_) => Vec::new(),
    }
}
96
/// Evict `name` from the local pacha cache.
///
/// # Errors
///
/// * `500` / `registry_error` if the fetcher cannot be constructed or the
///   removal itself fails.
/// * `404` / `not_found` if the removal reports that no such model was cached.
#[cfg(feature = "native")]
fn remove_cached_model(name: &str) -> Result<StatusCode, (StatusCode, Json<ErrorResponse>)> {
    // Both infrastructure failures share the same status and error code.
    let registry_error = |message: String| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(ErrorResponse::new(message, "registry_error", 500)),
        )
    };

    let mut fetcher = pacha::fetcher::ModelFetcher::new()
        .map_err(|err| registry_error(format!("Registry: {err}")))?;

    let removed = fetcher
        .remove(name)
        .map_err(|err| registry_error(format!("Remove failed: {err}")))?;

    if removed {
        Ok(StatusCode::NO_CONTENT)
    } else {
        Err((
            StatusCode::NOT_FOUND,
            Json(ErrorResponse::new(format!("Model {name} not in cache"), "not_found", 404)),
        ))
    }
}
118
119// ============================================================================
120// Dry-run registry (no native feature)
121// ============================================================================
122
123#[cfg(not(feature = "native"))]
124fn pull_model(model_ref: &str) -> Result<Json<PullResult>, (StatusCode, Json<ErrorResponse>)> {
125    Ok(Json(PullResult {
126        model_ref: model_ref.to_string(),
127        status: "dry_run".to_string(),
128        path: None,
129        size_bytes: None,
130        cache_hit: false,
131        format: None,
132    }))
133}
134
135#[cfg(not(feature = "native"))]
136fn list_cached_models() -> Vec<CachedModelInfo> {
137    Vec::new()
138}
139
140#[cfg(not(feature = "native"))]
141fn remove_cached_model(_name: &str) -> Result<StatusCode, (StatusCode, Json<ErrorResponse>)> {
142    Ok(StatusCode::NO_CONTENT)
143}
144
145// ============================================================================
146// Types
147// ============================================================================
148
149/// Model pull request.
150#[derive(Debug, Deserialize)]
151pub struct PullRequest {
152    /// Model reference: "llama3:8b-q4", "pacha://model:version", or file path.
153    pub model_ref: String,
154}
155
156/// Model pull result.
157#[derive(Debug, Clone, Serialize)]
158pub struct PullResult {
159    pub model_ref: String,
160    pub status: String,
161    #[serde(skip_serializing_if = "Option::is_none")]
162    pub path: Option<String>,
163    #[serde(skip_serializing_if = "Option::is_none")]
164    pub size_bytes: Option<u64>,
165    pub cache_hit: bool,
166    #[serde(skip_serializing_if = "Option::is_none")]
167    pub format: Option<String>,
168}
169
170/// Cached model info.
171#[derive(Debug, Clone, Serialize)]
172pub struct CachedModelInfo {
173    pub name: String,
174    pub version: String,
175    pub path: String,
176    pub size_bytes: u64,
177    pub format: String,
178}
179
180/// Registry list response.
181#[derive(Debug, Serialize)]
182pub struct RegistryListResponse {
183    pub models: Vec<CachedModelInfo>,
184}