use crate::cache::CacheManager;
use crate::model::SdkError;
use crate::platform::current_platform;
use crate::source::detect_platform;
use crate::telemetry_optout::is_telemetry_opted_out;
use crate::{get_binding, DEFAULT_BINDING};
use log::{debug, info, warn};
use serde::{Deserialize, Serialize};
use sha2::{Digest, Sha256};
use std::collections::HashMap;
use std::fs::File;
use std::io::{BufReader, Read, Write};
use std::path::PathBuf;
use std::sync::Arc;
use std::time::{Duration, Instant};
use xybrid_core::http::{CircuitBreaker, CircuitConfig, RetryPolicy, RetryableError};
/// Primary registry endpoint; tried first by [`RegistryClient::default_client`].
pub const DEFAULT_REGISTRY_URL: &str = "https://registry.xybrid.dev";
/// Secondary registry endpoint, tried when the primary fails with a retryable error
/// or its circuit breaker is open.
pub const FALLBACK_REGISTRY_URL: &str = "https://r2.xybrid.dev";
/// Registry endpoints in the order [`RegistryClient::default_client`] tries them.
pub const REGISTRY_URLS: &[&str] = &[DEFAULT_REGISTRY_URL, FALLBACK_REGISTRY_URL];
/// HTTP header carrying binding / SDK / core / platform / backend info on registry
/// metadata requests (omitted when telemetry is opted out).
pub const CLIENT_HEADER_NAME: &str = "X-Xybrid-Client";
/// Build the `X-Xybrid-Client` header value for `binding`, honouring the user's
/// telemetry opt-out setting.
///
/// Returns `None` when telemetry is opted out, meaning no header should be sent.
pub fn build_client_header(binding: &str) -> Option<String> {
    let opted_out = is_telemetry_opted_out();
    build_client_header_with_optout(binding, opted_out)
}
/// Core of [`build_client_header`] with the opt-out decision injected (for tests).
///
/// When `opted_out` is false, returns a `key=value; ...` string containing the
/// sanitized binding, this crate's version, the core version, the current platform
/// and the comma-joined list of enabled core backends (which may be empty).
fn build_client_header_with_optout(binding: &str, opted_out: bool) -> Option<String> {
    if opted_out {
        return None;
    }
    let header = format!(
        "binding={}; sdk_version={}; core_version={}; platform={}; backends={}",
        sanitize_binding(binding),
        env!("CARGO_PKG_VERSION"),
        xybrid_core::VERSION,
        current_platform(),
        xybrid_core::features::enabled().join(","),
    );
    Some(header)
}
/// Classify a download URL by hosting provider, for the download telemetry event.
///
/// Classification looks only at the URL's host (case-insensitive), so a provider name
/// that merely appears in the path or query string is not mistaken for the host.
/// A host matches a provider if it equals the provider domain or is a subdomain of it.
///
/// Returns `"huggingface"`, `"r2"` or `"other"`.
fn classify_download_source(url: &str) -> &'static str {
    let lower = url.trim().to_ascii_lowercase();
    let host = download_url_host(&lower);
    let is_host_or_subdomain = |domain: &str| {
        host == domain
            || host
                .strip_suffix(domain)
                .map_or(false, |prefix| prefix.ends_with('.'))
    };
    if is_host_or_subdomain("huggingface.co") || is_host_or_subdomain("hf.co") {
        "huggingface"
    } else if is_host_or_subdomain("r2.xybrid.dev")
        || is_host_or_subdomain("r2.cloudflarestorage.com")
        || is_host_or_subdomain("r2.dev")
    {
        "r2"
    } else {
        "other"
    }
}

/// Extract the host from a URL string: drops scheme, userinfo, port, path, query and
/// fragment. Inputs without `://` are treated as starting at the authority.
fn download_url_host(url: &str) -> &str {
    let after_scheme = url.split_once("://").map_or(url, |(_, rest)| rest);
    let authority = after_scheme
        .split(|c: char| c == '/' || c == '?' || c == '#')
        .next()
        .unwrap_or("");
    let host_port = authority.rsplit('@').next().unwrap_or(authority);
    // IPv6 literals (`[::1]:443`) can never match a provider domain; leave them intact
    // rather than splitting on their internal colons.
    if host_port.starts_with('[') {
        return host_port;
    }
    host_port
        .split(':')
        .next()
        .unwrap_or(host_port)
        .trim_end_matches('.')
}
/// Return `binding` if it is non-empty and uses only `[a-z0-9_-]`; otherwise fall back
/// to `DEFAULT_BINDING`. This keeps header-breaking characters (`;`, spaces, `/`, ...)
/// out of the `X-Xybrid-Client` value.
fn sanitize_binding(binding: &str) -> &str {
    let allowed = |c: char| matches!(c, 'a'..='z' | '0'..='9' | '_' | '-');
    if binding.is_empty() || !binding.chars().all(allowed) {
        DEFAULT_BINDING
    } else {
        binding
    }
}
// TCP connect timeout (ms) shared by the registry agent and the download agent.
const CONNECT_TIMEOUT_MS: u64 = 5000;
// Overall request timeout (ms) for registry metadata calls made through `RegistryClient::agent`.
const REQUEST_TIMEOUT_MS: u64 = 15000;
/// Client for the model registry: metadata queries with multi-URL fallback, retries
/// and per-URL circuit breakers, plus downloading and caching of model bundles.
pub struct RegistryClient {
    // Registry base URLs, tried in order; never empty (enforced by `new`).
    api_urls: Vec<String>,
    // Local bundle cache / extraction storage.
    cache: CacheManager,
    // HTTP agent used for registry metadata calls.
    agent: ureq::Agent,
    // One circuit breaker per entry in `api_urls`, same index.
    circuits: Vec<Arc<CircuitBreaker>>,
    // Retry policy applied per URL for metadata calls.
    retry_policy: RetryPolicy,
    // SDK binding name reported in the `X-Xybrid-Client` header.
    binding: &'static str,
}
impl RegistryClient {
/// Create a client that tries `api_urls` in order, with one circuit breaker per URL.
///
/// Trailing `/` characters are stripped from each base URL, because request URLs are
/// built as `"{base}/v1/..."` and a trailing slash would otherwise yield `//v1/...`.
///
/// # Errors
/// Returns `SdkError::ConfigError` if `api_urls` is empty; propagates errors from
/// `CacheManager::new`.
pub fn new(api_urls: Vec<String>) -> Result<Self, SdkError> {
    if api_urls.is_empty() {
        return Err(SdkError::ConfigError(
            "No registry URLs provided".to_string(),
        ));
    }
    let api_urls: Vec<String> = api_urls
        .into_iter()
        .map(|url| url.trim_end_matches('/').to_string())
        .collect();
    let agent = ureq::AgentBuilder::new()
        .timeout_connect(Duration::from_millis(CONNECT_TIMEOUT_MS))
        .timeout(Duration::from_millis(REQUEST_TIMEOUT_MS))
        .build();
    let cache = CacheManager::new()?;
    // Index-aligned with `api_urls`: circuits[i] guards api_urls[i].
    let circuits: Vec<Arc<CircuitBreaker>> = api_urls
        .iter()
        .map(|_| Arc::new(CircuitBreaker::new(CircuitConfig::default())))
        .collect();
    debug!(
        "RegistryClient created with {} URLs, cache_dir={}",
        api_urls.len(),
        cache.cache_dir().display()
    );
    Ok(Self {
        api_urls,
        cache,
        agent,
        circuits,
        retry_policy: RetryPolicy::default(),
        binding: get_binding(),
    })
}
pub fn with_binding(mut self, binding: &'static str) -> Self {
self.binding = binding;
self
}
pub fn binding(&self) -> &'static str {
self.binding
}
fn apply_client_header(&self, req: ureq::Request) -> ureq::Request {
self.apply_client_header_with_optout(req, is_telemetry_opted_out())
}
fn apply_client_header_with_optout(
&self,
req: ureq::Request,
opted_out: bool,
) -> ureq::Request {
match build_client_header_with_optout(self.binding, opted_out) {
Some(value) => req.set(CLIENT_HEADER_NAME, &value),
None => req,
}
}
/// Create a client for a single registry URL (no fallback).
pub fn with_url(api_url: impl Into<String>) -> Result<Self, SdkError> {
    Self::new(vec![api_url.into()])
}
/// Create a client for the built-in registry URLs (primary, then fallback).
pub fn default_client() -> Result<Self, SdkError> {
    Self::new(REGISTRY_URLS.iter().map(|s| s.to_string()).collect())
}
/// Create a client from the `XYBRID_REGISTRY_URL` environment variable.
///
/// Falls back to [`Self::default_client`] when the variable is unset, not valid
/// Unicode, or blank. Surrounding whitespace is trimmed: a stray newline, such as one
/// from `$(cat file)`, would otherwise produce an invalid URL.
pub fn from_env() -> Result<Self, SdkError> {
    match std::env::var("XYBRID_REGISTRY_URL") {
        Ok(url) if !url.trim().is_empty() => Self::with_url(url.trim()),
        _ => Self::default_client(),
    }
}
/// The first (preferred) registry base URL. `new` guarantees at least one exists.
pub fn primary_url(&self) -> &str {
    self.api_urls[0].as_str()
}
/// True when every URL's circuit breaker reports open, i.e. no URL is currently usable.
pub fn is_circuit_open(&self) -> bool {
    !self.circuits.iter().any(|breaker| !breaker.is_open())
}
/// Reset every URL's circuit breaker.
pub fn reset_circuit(&self) {
    self.circuits.iter().for_each(|breaker| {
        breaker.reset();
    });
}
/// Fetch `GET {base}/v1/models` and return the model summaries.
pub fn list_models(&self) -> Result<Vec<ModelSummary>, SdkError> {
    let response = self.execute_with_fallback(|base| {
        let endpoint = format!("{}/v1/models", base);
        let outcome = self.apply_client_header(self.agent.get(&endpoint)).call();
        self.handle_response(outcome, "list models")
    })?;
    let parsed: ListModelsResponse = response
        .into_json()
        .map_err(|e| SdkError::NetworkError(format!("Failed to parse response: {}", e)))?;
    Ok(parsed.models)
}
/// Fetch `GET {base}/v1/models/{mask}`; a 404 maps to `SdkError::ModelNotFound`.
pub fn get_model(&self, mask: &str) -> Result<ModelDetail, SdkError> {
    let response = self.execute_with_fallback(|base| {
        let endpoint = format!("{}/v1/models/{}", base, mask);
        let outcome = self.apply_client_header(self.agent.get(&endpoint)).call();
        self.handle_response_with_404(outcome, "get model", || {
            SdkError::ModelNotFound(format!("Model '{}' not found", mask))
        })
    })?;
    response
        .into_json()
        .map_err(|e| SdkError::NetworkError(format!("Failed to parse response: {}", e)))
}
/// Resolve `mask` to the concrete variant for `platform` via
/// `GET {base}/v1/models/{mask}/resolve?platform=...`.
///
/// `platform` defaults to `detect_platform()`. The value is added through
/// `Request::query`, so ureq percent-encodes it. Previously it was spliced into the
/// URL verbatim, which broke on spaces, `&`, `#` and similar characters.
///
/// # Errors
/// `SdkError::ModelNotFound` on 404; otherwise network/status errors from the
/// fallback/retry machinery, or `NetworkError` if the body cannot be parsed.
pub fn resolve(&self, mask: &str, platform: Option<&str>) -> Result<ResolvedVariant, SdkError> {
    let platform = platform.map(String::from).unwrap_or_else(detect_platform);
    self.execute_with_fallback(|api_url| {
        let url = format!("{}/v1/models/{}/resolve", api_url, mask);
        let req = self.apply_client_header(self.agent.get(&url).query("platform", &platform));
        let response = req.call();
        self.handle_response_with_404(response, "resolve model", || {
            SdkError::ModelNotFound(format!(
                "Model '{}' not found or no compatible variant for platform '{}'",
                mask, platform
            ))
        })
    })
    .and_then(|response| {
        let resolve_response: ResolveResponse = response
            .into_json()
            .map_err(|e| SdkError::NetworkError(format!("Failed to parse response: {}", e)))?;
        Ok(resolve_response.resolved)
    })
}
/// Run `operation` against each registry URL in order until one succeeds.
///
/// URLs whose circuit cannot execute are skipped. A non-retryable error is returned
/// immediately; a retryable one moves on to the next URL. If nothing succeeds, the last
/// retryable error is returned, or a generic `NetworkError` if every URL was skipped.
fn execute_with_fallback<T, F>(&self, mut operation: F) -> Result<T, SdkError>
where
    F: FnMut(&str) -> Result<T, SdkError>,
{
    let mut most_recent_failure: Option<SdkError> = None;
    for (position, base_url) in self.api_urls.iter().enumerate() {
        let breaker = &self.circuits[position];
        if !breaker.can_execute() {
            debug!("Skipping {} (circuit open)", base_url);
            continue;
        }
        let failure = match self.execute_with_retry_for_url(base_url, breaker, &mut operation) {
            Ok(value) => {
                if position > 0 {
                    info!("Request succeeded using fallback URL: {}", base_url);
                }
                return Ok(value);
            }
            Err(failure) => failure,
        };
        if !failure.is_retryable() {
            return Err(failure);
        }
        debug!("URL {} failed: {}, trying next", base_url, failure);
        most_recent_failure = Some(failure);
    }
    Err(most_recent_failure.unwrap_or_else(|| {
        SdkError::NetworkError("All registry URLs failed or circuits open".to_string())
    }))
}
/// Run `operation(api_url)` up to `retry_policy.max_attempts` times, updating `circuit`.
///
/// Before each attempt, sleeps for the previous error's `retry_after()` if it has one,
/// otherwise for `retry_policy.delay_for_attempt(attempt)`. The circuit is then
/// re-checked; if it cannot execute, `SdkError::CircuitOpen` is returned.
///
/// Error handling per attempt:
/// - `SdkError::Offline` is returned immediately, without recording a failure on the
///   breaker.
/// - Any other error records a failure. `RateLimited` additionally calls
///   `record_rate_limited`.
/// - Non-retryable errors are returned immediately; retryable ones are kept for the
///   next attempt.
fn execute_with_retry_for_url<T, F>(
    &self,
    api_url: &str,
    circuit: &Arc<CircuitBreaker>,
    operation: &mut F,
) -> Result<T, SdkError>
where
    F: FnMut(&str) -> Result<T, SdkError>,
{
    let mut last_error: Option<SdkError> = None;
    for attempt in 0..self.retry_policy.max_attempts {
        // Prefer the server/error-provided backoff; fall back to the policy's schedule.
        let delay = if let Some(ref err) = last_error {
            err.retry_after()
                .unwrap_or_else(|| self.retry_policy.delay_for_attempt(attempt))
        } else {
            self.retry_policy.delay_for_attempt(attempt)
        };
        if !delay.is_zero() {
            std::thread::sleep(delay);
        }
        // Re-check after sleeping: earlier failures (here or elsewhere) may have opened it.
        if !circuit.can_execute() {
            return Err(SdkError::CircuitOpen(format!(
                "Circuit breaker open for {}",
                api_url
            )));
        }
        match operation(api_url) {
            Ok(result) => {
                circuit.record_success();
                return Ok(result);
            }
            Err(err) => {
                // Offline is a local condition, not a registry fault: do not count it
                // toward the breaker and do not retry this URL.
                if matches!(&err, SdkError::Offline(_)) {
                    return Err(err);
                }
                circuit.record_failure();
                if let SdkError::RateLimited { .. } = &err {
                    circuit.record_rate_limited();
                }
                if !err.is_retryable() {
                    return Err(err);
                }
                last_error = Some(err);
            }
        }
    }
    Err(last_error.unwrap_or_else(|| {
        SdkError::NetworkError(format!("All retry attempts exhausted for {}", api_url))
    }))
}
fn handle_response(
&self,
response: Result<ureq::Response, ureq::Error>,
operation: &str,
) -> Result<ureq::Response, SdkError> {
match response {
Ok(resp) => {
if resp.status() == 200 {
Ok(resp)
} else {
Err(self.status_to_error(resp.status(), operation))
}
}
Err(e) => Err(self.ureq_error_to_sdk_error(e, operation)),
}
}
fn handle_response_with_404<F>(
&self,
response: Result<ureq::Response, ureq::Error>,
operation: &str,
not_found_err: F,
) -> Result<ureq::Response, SdkError>
where
F: FnOnce() -> SdkError,
{
match response {
Ok(resp) => {
if resp.status() == 200 {
Ok(resp)
} else if resp.status() == 404 {
Err(not_found_err())
} else {
Err(self.status_to_error(resp.status(), operation))
}
}
Err(ureq::Error::Status(404, _)) => Err(not_found_err()),
Err(e) => Err(self.ureq_error_to_sdk_error(e, operation)),
}
}
/// Map a non-200 HTTP status to an `SdkError`.
///
/// 429 becomes `RateLimited` with a 60s default; selected client errors (400/401/403/422)
/// become `ConfigError`; 502–504 and every other status become `NetworkError`.
fn status_to_error(&self, status: u16, operation: &str) -> SdkError {
    match status {
        429 => SdkError::RateLimited {
            retry_after_secs: 60,
        },
        400 | 401 | 403 | 422 => SdkError::ConfigError(format!(
            "Registry {} failed with status {} (client error)",
            operation, status
        )),
        502 | 503 | 504 => SdkError::NetworkError(format!(
            "Registry {} failed with status {} (server error)",
            operation, status
        )),
        other => {
            SdkError::NetworkError(format!("Registry {} returned status {}", operation, other))
        }
    }
}
fn ureq_error_to_sdk_error(&self, error: ureq::Error, operation: &str) -> SdkError {
match error {
ureq::Error::Status(status, _) => self.status_to_error(status, operation),
ureq::Error::Transport(transport) => {
let kind = transport.kind();
match kind {
ureq::ErrorKind::Dns => SdkError::Offline(format!(
"Failed to {} (DNS resolution failed)",
operation
)),
ureq::ErrorKind::ConnectionFailed => SdkError::Offline(format!(
"Failed to {} (connection refused or host unreachable)",
operation
)),
ureq::ErrorKind::Io => SdkError::Offline(format!(
"Failed to {} (network I/O error: {})",
operation,
transport.message().unwrap_or("unknown")
)),
_ => SdkError::NetworkError(format!("Failed to {}: {}", operation, transport)),
}
}
}
}
pub fn is_cached(&self, mask: &str, platform: Option<&str>) -> Result<bool, SdkError> {
let resolved = self.resolve(mask, platform)?;
let cache_path = self.get_cache_path(&resolved);
if !cache_path.exists() {
return Ok(false);
}
if !resolved.sha256.is_empty() {
let hash = compute_sha256(&cache_path)?;
Ok(hash == resolved.sha256)
} else {
Ok(true)
}
}
pub fn get_cache_path(&self, resolved: &ResolvedVariant) -> PathBuf {
let model_name = resolved
.hf_repo
.split('/')
.next_back()
.unwrap_or(&resolved.hf_repo);
self.cache.cache_dir().join(model_name).join(&resolved.file)
}
pub fn fetch<F>(
&self,
mask: &str,
platform: Option<&str>,
progress_callback: F,
) -> Result<PathBuf, SdkError>
where
F: Fn(f32),
{
let resolved = self.resolve(mask, platform)?;
let cache_path = self.get_cache_path(&resolved);
debug!(
"Cache check for '{}': path={}, exists={}, sha256_provided={}",
mask,
cache_path.display(),
cache_path.exists(),
!resolved.sha256.is_empty()
);
if cache_path.exists() && !resolved.sha256.is_empty() {
let hash = match read_cached_hash(&cache_path) {
Some(cached_hash) => {
debug!("Using cached hash for '{}'", mask);
cached_hash
}
None => {
debug!("Computing hash for '{}' (no cached hash found)", mask);
let computed = compute_sha256(&cache_path)?;
write_cached_hash(&cache_path, &computed);
computed
}
};
debug!(
"Cache verification for '{}': expected={}, actual={}",
mask, resolved.sha256, hash
);
if hash == resolved.sha256 {
info!("Cache hit for '{}' at {}", mask, cache_path.display());
return Ok(cache_path);
}
info!("Cache hash mismatch for '{}', re-downloading", mask);
std::fs::remove_file(&cache_path).ok();
remove_cached_hash(&cache_path);
} else if cache_path.exists() {
info!(
"Cache exists for '{}' but no sha256 to verify, re-downloading",
mask
);
} else {
info!(
"Cache miss for '{}', downloading to {}",
mask,
cache_path.display()
);
}
if let Some(parent) = cache_path.parent() {
std::fs::create_dir_all(parent)?;
}
info!("Downloading '{}' from {}", mask, resolved.download_url);
let download_started = Instant::now();
self.download_with_progress(
&resolved.download_url,
&cache_path,
resolved.size_bytes,
progress_callback,
)?;
let download_duration = download_started.elapsed();
let bytes_downloaded = std::fs::metadata(&cache_path)
.map(|m| m.len())
.unwrap_or(resolved.size_bytes);
crate::telemetry::publish_model_download(
mask,
bytes_downloaded,
classify_download_source(&resolved.download_url),
download_duration.as_millis().min(u32::MAX as u128) as u32,
);
if !resolved.sha256.is_empty() {
let hash = compute_sha256(&cache_path)?;
if hash != resolved.sha256 {
std::fs::remove_file(&cache_path).ok();
return Err(SdkError::CacheError(format!(
"SHA256 mismatch: expected {}, got {}",
resolved.sha256, hash
)));
}
write_cached_hash(&cache_path, &hash);
info!(
"Download complete for '{}', SHA256 verified, cached at {}",
mask,
cache_path.display()
);
} else {
info!(
"Download complete for '{}' (no SHA256 verification), cached at {}",
mask,
cache_path.display()
);
}
Ok(cache_path)
}
pub fn fetch_extracted<F>(
&self,
mask: &str,
platform: Option<&str>,
progress_callback: F,
) -> Result<PathBuf, SdkError>
where
F: Fn(f32),
{
if let Some(extract_dir) = self.resolve_offline(mask) {
debug!(
"Using locally extracted model '{}' at {} (skipping registry)",
mask,
extract_dir.display()
);
return Ok(extract_dir);
}
let resolved = self.resolve(mask, platform)?;
if resolved.passthrough {
self.fetch_passthrough(mask, &resolved, progress_callback)
} else {
let xyb_path = self.fetch(mask, platform, progress_callback)?;
self.cache.ensure_extracted(&xyb_path)
}
}
fn fetch_passthrough<F>(
&self,
mask: &str,
resolved: &ResolvedVariant,
progress_callback: F,
) -> Result<PathBuf, SdkError>
where
F: Fn(f32),
{
let extract_dir = self.cache.extraction_dir(mask);
let model_file_path = extract_dir.join(&resolved.file);
let metadata_path = extract_dir.join("model_metadata.json");
if model_file_path.exists() && metadata_path.exists() {
if resolved.sha256.is_empty() {
warn!(
"Passthrough cache hit for '{}' (no hash verification) at {}",
mask,
extract_dir.display()
);
return Ok(extract_dir);
}
if let Some(cached_hash) = read_cached_hash(&model_file_path) {
if cached_hash == resolved.sha256 {
info!(
"Passthrough cache hit for '{}' at {}",
mask,
extract_dir.display()
);
return Ok(extract_dir);
}
info!("Passthrough hash mismatch for '{}', re-downloading", mask);
}
}
std::fs::create_dir_all(&extract_dir).map_err(|e| {
SdkError::CacheError(format!("Failed to create extraction directory: {}", e))
})?;
info!(
"Passthrough download '{}' from {}",
mask, resolved.download_url
);
let download_started = Instant::now();
self.download_with_progress(
&resolved.download_url,
&model_file_path,
resolved.size_bytes,
&progress_callback,
)?;
let download_duration = download_started.elapsed();
let bytes_downloaded = std::fs::metadata(&model_file_path)
.map(|m| m.len())
.unwrap_or(resolved.size_bytes);
crate::telemetry::publish_model_download(
mask,
bytes_downloaded,
classify_download_source(&resolved.download_url),
download_duration.as_millis().min(u32::MAX as u128) as u32,
);
if !resolved.sha256.is_empty() {
let hash = compute_sha256(&model_file_path)?;
if hash != resolved.sha256 {
std::fs::remove_file(&model_file_path).ok();
return Err(SdkError::CacheError(format!(
"Passthrough SHA256 mismatch: expected {}, got {}",
resolved.sha256, hash
)));
}
write_cached_hash(&model_file_path, &hash);
info!("Passthrough SHA256 verified for '{}'", mask);
}
if let Some(ref metadata) = resolved.model_metadata {
let metadata_json = serde_json::to_string_pretty(metadata).map_err(|e| {
SdkError::CacheError(format!("Failed to serialize model metadata: {}", e))
})?;
std::fs::write(&metadata_path, metadata_json).map_err(|e| {
SdkError::CacheError(format!("Failed to write model_metadata.json: {}", e))
})?;
info!(
"Wrote model_metadata.json for passthrough model '{}' at {}",
mask,
metadata_path.display()
);
} else {
return Err(SdkError::CacheError(format!(
"Passthrough variant for '{}' has no model_metadata in registry response",
mask
)));
}
Ok(extract_dir)
}
/// Whether the cache manager reports `model_id` as already extracted.
pub fn is_extracted(&self, model_id: &str) -> bool {
    self.cache.is_extracted(model_id)
}
/// Extraction directory the cache manager uses for `model_id`.
pub fn extraction_dir(&self, model_id: &str) -> PathBuf {
    self.cache.extraction_dir(model_id)
}
/// The extraction directory for `mask` if it is already extracted; no network access.
pub fn resolve_offline(&self, mask: &str) -> Option<PathBuf> {
    self.is_extracted(mask).then(|| self.extraction_dir(mask))
}
/// Ids of all models the cache manager reports as extracted.
pub fn list_offline_models(&self) -> Vec<String> {
    self.cache.list_extracted_model_ids()
}
/// Download `url` to `dest` using `RetryPolicy::conservative()`.
///
/// Between attempts, sleeps for the last error's `retry_after()` or the policy's delay.
/// After any failed attempt, retryable or not, the partially written `dest` is
/// removed so no truncated file is left behind.
fn download_with_progress<F>(
    &self,
    url: &str,
    dest: &PathBuf,
    total_size: u64,
    progress_callback: F,
) -> Result<(), SdkError>
where
    F: Fn(f32),
{
    let download_policy = RetryPolicy::conservative();
    let mut last_error: Option<SdkError> = None;
    for attempt in 0..download_policy.max_attempts {
        let delay = match last_error {
            Some(ref err) => err
                .retry_after()
                .unwrap_or_else(|| download_policy.delay_for_attempt(attempt)),
            None => download_policy.delay_for_attempt(attempt),
        };
        if !delay.is_zero() {
            std::thread::sleep(delay);
        }
        match self.try_download(url, dest, total_size, &progress_callback) {
            Ok(()) => return Ok(()),
            Err(err) => {
                std::fs::remove_file(dest).ok();
                if !err.is_retryable() {
                    return Err(err);
                }
                last_error = Some(err);
            }
        }
    }
    Err(last_error.unwrap_or_else(|| {
        SdkError::NetworkError("Download failed after all retry attempts".to_string())
    }))
}
fn try_download<F>(
&self,
url: &str,
dest: &PathBuf,
total_size: u64,
progress_callback: &F,
) -> Result<(), SdkError>
where
F: Fn(f32),
{
let download_agent = ureq::AgentBuilder::new()
.timeout_connect(Duration::from_millis(CONNECT_TIMEOUT_MS))
.timeout(Duration::from_secs(300)) .build();
let response = download_agent
.get(url)
.call()
.map_err(|e| self.ureq_error_to_sdk_error(e, "download bundle"))?;
if response.status() != 200 {
return Err(self.status_to_error(response.status(), "download bundle"));
}
let mut file = File::create(dest)?;
let mut reader = response.into_reader();
let mut buffer = [0u8; 8192];
let mut downloaded: u64 = 0;
loop {
let bytes_read = reader
.read(&mut buffer)
.map_err(|e| SdkError::NetworkError(format!("Read error: {}", e)))?;
if bytes_read == 0 {
break;
}
file.write_all(&buffer[..bytes_read])?;
downloaded += bytes_read as u64;
if total_size > 0 {
let progress = downloaded as f32 / total_size as f32;
progress_callback(progress.min(1.0));
}
}
progress_callback(1.0);
Ok(())
}
pub fn clear_cache(&self, mask: &str) -> Result<(), SdkError> {
let model_dir = self.cache.cache_dir().join(mask);
if model_dir.exists() {
std::fs::remove_dir_all(&model_dir)?;
}
Ok(())
}
/// Clear the whole cache via the cache manager, wrapping its error as `CacheError`.
pub fn clear_all_cache(&mut self) -> Result<(), SdkError> {
    match self.cache.clear() {
        Ok(_) => Ok(()),
        Err(e) => Err(SdkError::CacheError(e.to_string())),
    }
}
/// Summarise the cache: number of top-level sub-directories and their total size.
///
/// NOTE(review): every sub-directory of the cache dir counts as a "model". If the
/// extraction directories live under the cache dir, that directory is counted too —
/// confirm against `CacheManager`'s layout.
pub fn cache_stats(&self) -> Result<CacheStats, SdkError> {
    let root = self.cache.cache_dir();
    let (mut model_count, mut total_size_bytes) = (0usize, 0u64);
    if root.exists() {
        for entry in std::fs::read_dir(root)? {
            let path = entry?.path();
            if !path.is_dir() {
                continue;
            }
            model_count += 1;
            total_size_bytes += dir_size(&path)?;
        }
    }
    Ok(CacheStats {
        total_size_bytes,
        model_count,
        cache_path: root.to_path_buf(),
    })
}
}
/// Stream the file at `path` through SHA-256 and return the lowercase hex digest.
fn compute_sha256(path: &PathBuf) -> Result<String, SdkError> {
    let mut reader = BufReader::new(File::open(path)?);
    let mut hasher = Sha256::new();
    let mut chunk = [0u8; 8192];
    loop {
        match reader.read(&mut chunk)? {
            0 => break,
            filled => hasher.update(&chunk[..filled]),
        }
    }
    let digest = hasher.finalize();
    Ok(format!("{:x}", digest))
}
/// Path of the SHA-256 sidecar for `file_path`: the same path with `.sha256` appended
/// (not an extension replacement, so `model.xyb` becomes `model.xyb.sha256`).
fn hash_cache_path(file_path: &PathBuf) -> PathBuf {
    let mut sidecar_name = file_path.clone().into_os_string();
    sidecar_name.push(".sha256");
    sidecar_name.into()
}
/// Read the sidecar SHA-256 for `bundle_path`, if it is trustworthy.
///
/// Returns `None` in any of these cases:
/// - the sidecar is missing;
/// - either file's mtime is unavailable;
/// - the bundle was modified after the sidecar;
/// - the sidecar content (trimmed) is not exactly 64 hex characters.
fn read_cached_hash(bundle_path: &PathBuf) -> Option<String> {
    let sidecar = hash_cache_path(bundle_path);
    if !sidecar.exists() {
        return None;
    }
    let modified = |p: &PathBuf| std::fs::metadata(p).ok()?.modified().ok();
    let bundle_mtime = modified(bundle_path)?;
    let sidecar_mtime = modified(&sidecar)?;
    // A bundle newer than its sidecar means the recorded hash may be stale.
    if sidecar_mtime < bundle_mtime {
        return None;
    }
    let contents = std::fs::read_to_string(&sidecar).ok()?;
    let candidate = contents.trim();
    let well_formed = candidate.len() == 64 && candidate.chars().all(|c| c.is_ascii_hexdigit());
    well_formed.then(|| candidate.to_string())
}
/// Best-effort write of the sidecar hash for `bundle_path`; failures are ignored.
fn write_cached_hash(bundle_path: &PathBuf, hash: &str) {
    std::fs::write(hash_cache_path(bundle_path), hash).ok();
}
/// Best-effort removal of the sidecar hash for `bundle_path`; failures are ignored.
fn remove_cached_hash(bundle_path: &PathBuf) {
    std::fs::remove_file(hash_cache_path(bundle_path)).ok();
}
fn dir_size(path: &PathBuf) -> Result<u64, SdkError> {
let mut total: u64 = 0;
for entry in std::fs::read_dir(path)? {
let entry = entry?;
let metadata = entry.metadata()?;
if metadata.is_file() {
total += metadata.len();
} else if metadata.is_dir() {
total += dir_size(&entry.path())?;
}
}
Ok(total)
}
/// Body of `GET /v1/models`.
#[derive(Debug, Deserialize)]
struct ListModelsResponse {
    models: Vec<ModelSummary>,
}
/// One entry of the registry's model list.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelSummary {
    pub id: String,
    pub family: String,
    pub task: String,
    pub parameters: u64,
    pub description: String,
    pub variants: Vec<String>,
}
/// Body of `GET /v1/models/{mask}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelDetail {
    pub id: String,
    pub family: String,
    pub task: String,
    pub parameters: u64,
    pub description: String,
    pub default_variant: Option<String>,
    pub variants: HashMap<String, VariantInfo>,
}
/// A model variant as listed in [`ModelDetail::variants`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct VariantInfo {
    pub platform: String,
    pub format: String,
    pub quantization: String,
    pub size_bytes: u64,
    pub hf_repo: String,
    pub file: String,
}
/// Body of `GET /v1/models/{mask}/resolve`.
#[derive(Debug, Deserialize)]
struct ResolveResponse {
    // Echoed back by the registry; deserialized to validate the shape but never read.
    #[allow(dead_code)]
    mask: String,
    // Echoed back by the registry; deserialized to validate the shape but never read.
    #[allow(dead_code)]
    platform: String,
    resolved: ResolvedVariant,
}
/// The concrete variant chosen by the registry for a mask/platform pair.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResolvedVariant {
    /// Repository id; its last `/` segment names the cache sub-directory.
    pub hf_repo: String,
    /// File name used for the cached / passthrough file.
    pub file: String,
    pub download_url: String,
    pub format: String,
    pub quantization: String,
    /// Expected size; used for progress reporting and as a telemetry fallback.
    pub size_bytes: u64,
    /// Expected lowercase-hex SHA-256; empty string means "not provided, skip verification".
    pub sha256: String,
    /// When true, the file is used as-is instead of being extracted as a bundle.
    #[serde(default)]
    pub passthrough: bool,
    /// Written to `model_metadata.json` for passthrough variants (required there).
    #[serde(default)]
    pub model_metadata: Option<serde_json::Value>,
}
/// Summary of the on-disk cache as reported by `RegistryClient::cache_stats`.
#[derive(Debug, Clone)]
pub struct CacheStats {
    pub total_size_bytes: u64,
    pub model_count: usize,
    pub cache_path: PathBuf,
}
impl CacheStats {
    /// Human-readable size using 1024-based units: bytes as an integer, KB/MB with one
    /// decimal, GB with two decimals.
    pub fn total_size_human(&self) -> String {
        const KIB: u64 = 1024;
        const MIB: u64 = KIB * 1024;
        const GIB: u64 = MIB * 1024;
        let bytes = self.total_size_bytes;
        let as_float = bytes as f64;
        match bytes {
            b if b < KIB => format!("{} B", b),
            b if b < MIB => format!("{:.1} KB", as_float / KIB as f64),
            b if b < GIB => format!("{:.1} MB", as_float / MIB as f64),
            _ => format!("{:.2} GB", as_float / GIB as f64),
        }
    }
}
// Unit tests for the registry client: telemetry header construction, URL
// classification, client construction, offline/circuit-breaker behaviour, and
// mock-server checks that the client header is (or is not) sent.
#[cfg(test)]
mod tests {
    use super::*;
    // Cloudflare R2 hosts (custom domain, storage endpoint, public bucket) classify as "r2".
    #[test]
    fn classify_download_source_recognises_r2_hosts() {
        assert_eq!(
            classify_download_source("https://r2.xybrid.dev/v1/kokoro/universal.xyb"),
            "r2"
        );
        assert_eq!(
            classify_download_source("https://abcd1234.r2.cloudflarestorage.com/bundles/x.xyb"),
            "r2"
        );
        assert_eq!(
            classify_download_source("https://pub-xxx.r2.dev/x.xyb"),
            "r2"
        );
    }
    // Both the full and the short Hugging Face domains classify as "huggingface".
    #[test]
    fn classify_download_source_recognises_huggingface_hosts() {
        assert_eq!(
            classify_download_source(
                "https://huggingface.co/xybrid-ai/kokoro-82m/resolve/main/model.gguf"
            ),
            "huggingface"
        );
        assert_eq!(
            classify_download_source("https://hf.co/owner/repo/resolve/main/m.gguf"),
            "huggingface"
        );
    }
    // Unknown hosts and the empty string classify as "other".
    #[test]
    fn classify_download_source_falls_back_to_other() {
        assert_eq!(
            classify_download_source("https://cdn.example.com/m.gguf"),
            "other"
        );
        assert_eq!(classify_download_source(""), "other");
    }
    // The default client is configured with primary + fallback URLs, primary first.
    #[test]
    fn test_default_client() {
        let client = RegistryClient::default_client().unwrap();
        assert_eq!(client.api_urls.len(), 2);
        assert_eq!(client.primary_url(), DEFAULT_REGISTRY_URL);
    }
    // With telemetry enabled, the header carries every documented key.
    #[test]
    fn build_client_header_default_binding_has_all_fields() {
        let header = build_client_header_with_optout("rust", false)
            .expect("header must be built when not opted out");
        assert!(
            header.starts_with("binding=rust;"),
            "header should start with sanitized binding: {}",
            header
        );
        assert!(
            header.contains("sdk_version="),
            "missing sdk_version: {}",
            header
        );
        assert!(
            header.contains("core_version="),
            "missing core_version: {}",
            header
        );
        assert!(
            header.contains(&format!("platform={}", current_platform())),
            "platform mismatch: {}",
            header
        );
        assert!(
            header.contains("backends="),
            "missing backends key: {}",
            header
        );
    }
    // Opting out of telemetry suppresses the header entirely.
    #[test]
    fn build_client_header_opt_out_returns_none() {
        assert!(build_client_header_with_optout("rust", true).is_none());
    }
    // A binding containing `;`/spaces is replaced by DEFAULT_BINDING (assumed to be "rust"
    // here), preventing header-field injection.
    #[test]
    fn build_client_header_malformed_binding_falls_back_to_default() {
        let header = build_client_header_with_optout("flutter; injected", false)
            .expect("header must be built when not opted out");
        assert!(
            header.starts_with("binding=rust;"),
            "malformed binding must collapse to DEFAULT_BINDING: {}",
            header
        );
        assert!(
            !header.contains("injected"),
            "smuggled tokens must not appear in the header: {}",
            header
        );
    }
    // Uppercase is outside the allowlist, so it falls back to the default binding.
    #[test]
    fn build_client_header_uppercase_binding_falls_back_to_default() {
        let header = build_client_header_with_optout("FLUTTER", false).unwrap();
        assert!(
            header.starts_with("binding=rust;"),
            "uppercase binding is not in the [a-z0-9_-] allowlist: {}",
            header
        );
    }
    // An empty binding falls back to the default binding.
    #[test]
    fn build_client_header_empty_binding_falls_back_to_default() {
        let header = build_client_header_with_optout("", false).unwrap();
        assert!(header.starts_with("binding=rust;"));
    }
    // Every shipped binding name passes sanitization unchanged.
    #[test]
    fn build_client_header_accepts_known_bindings() {
        for binding in ["rust", "flutter", "kotlin", "swift", "unity"] {
            let header = build_client_header_with_optout(binding, false).unwrap();
            let prefix = format!("binding={};", binding);
            assert!(
                header.starts_with(&prefix),
                "binding `{}` should pass sanitization: {}",
                binding,
                header
            );
        }
    }
    // NOTE(review): despite its name, this test does not force an empty backends list;
    // it checks whatever backends the current build enables (no leading comma, key present).
    #[test]
    fn build_client_header_renders_empty_backends_list_without_panic() {
        let header = build_client_header_with_optout("rust", false).unwrap();
        assert!(
            header.contains("backends="),
            "header always carries the backends key: {}",
            header
        );
        assert!(
            !header.contains("backends=,"),
            "leading comma in backends list: {}",
            header
        );
    }
    // Lowercase letters, digits, `_` and `-` are all accepted as-is.
    #[test]
    fn sanitize_binding_accepts_alphanumerics_underscore_and_hyphen() {
        assert_eq!(sanitize_binding("rust"), "rust");
        assert_eq!(sanitize_binding("flutter"), "flutter");
        assert_eq!(sanitize_binding("react-native"), "react-native");
        assert_eq!(sanitize_binding("snake_case"), "snake_case");
        assert_eq!(sanitize_binding("v2"), "v2");
    }
    // Empty, uppercase, whitespace, `;` and `/` all collapse to DEFAULT_BINDING.
    #[test]
    fn sanitize_binding_rejects_invalid_chars() {
        assert_eq!(sanitize_binding(""), DEFAULT_BINDING);
        assert_eq!(sanitize_binding("Flutter"), DEFAULT_BINDING);
        assert_eq!(sanitize_binding("flutter app"), DEFAULT_BINDING);
        assert_eq!(sanitize_binding("flutter;injected"), DEFAULT_BINDING);
        assert_eq!(sanitize_binding("flu/tter"), DEFAULT_BINDING);
    }
    // `with_url` configures exactly one URL, which becomes the primary.
    #[test]
    fn test_single_url_client() {
        let client = RegistryClient::with_url("https://custom.example.com").unwrap();
        assert_eq!(client.api_urls.len(), 1);
        assert_eq!(client.primary_url(), "https://custom.example.com");
    }
    // REGISTRY_URLS lists the primary then the fallback.
    #[test]
    fn test_registry_urls_constant() {
        assert_eq!(REGISTRY_URLS.len(), 2);
        assert_eq!(REGISTRY_URLS[0], DEFAULT_REGISTRY_URL);
        assert_eq!(REGISTRY_URLS[1], FALLBACK_REGISTRY_URL);
    }
    // The cache path includes the repo's last segment and the file name.
    #[test]
    fn test_cache_path() {
        let client = RegistryClient::default_client().unwrap();
        let resolved = ResolvedVariant {
            hf_repo: "xybrid-ai/kokoro-82m".to_string(),
            file: "universal.xyb".to_string(),
            download_url: "https://example.com/bundle.xyb".to_string(),
            format: "onnx".to_string(),
            quantization: "fp16".to_string(),
            size_bytes: 100000,
            sha256: "abc123".to_string(),
            passthrough: false,
            model_metadata: None,
        };
        let path = client.get_cache_path(&resolved);
        assert!(path.to_string_lossy().contains("kokoro-82m"));
        assert!(path.to_string_lossy().contains("universal.xyb"));
    }
    // The extraction directory path mentions "extracted" and the model id.
    #[test]
    fn test_extraction_dir() {
        let client = RegistryClient::default_client().unwrap();
        let dir = client.extraction_dir("test-model");
        assert!(dir.to_string_lossy().contains("extracted"));
        assert!(dir.to_string_lossy().contains("test-model"));
    }
    // Depends on the local cache state: assumes this id was never extracted.
    #[test]
    fn test_is_extracted_false_for_nonexistent() {
        let client = RegistryClient::default_client().unwrap();
        assert!(!client.is_extracted("nonexistent-model-12345"));
    }
    // Offline resolution of an unknown model returns None (no network involved).
    #[test]
    fn test_resolve_offline_none_for_nonexistent() {
        let client = RegistryClient::default_client().unwrap();
        assert!(client
            .resolve_offline("definitely-not-a-real-model-xyzzy-42")
            .is_none());
    }
    // `resolve_offline` agrees with `is_extracted`.
    #[test]
    fn test_resolve_offline_matches_is_extracted() {
        let client = RegistryClient::default_client().unwrap();
        let mask = "nonexistent-model-12345";
        assert_eq!(
            client.resolve_offline(mask).is_some(),
            client.is_extracted(mask)
        );
    }
    // NOTE(review): this assertion only runs if "some-model" happens to be extracted in
    // the local cache; otherwise the test passes vacuously.
    #[test]
    fn test_resolve_offline_returns_extraction_dir() {
        let client = RegistryClient::default_client().unwrap();
        let mask = "some-model";
        let expected = client.extraction_dir(mask);
        if let Some(actual) = client.resolve_offline(mask) {
            assert_eq!(actual, expected);
        }
    }
    // Offline errors must not count toward the breaker's failure threshold.
    #[test]
    fn test_offline_error_does_not_trip_circuit_breaker() {
        let client = RegistryClient::with_url("https://primary.example.invalid").unwrap();
        let circuit = client.circuits[0].clone();
        assert!(circuit.is_closed(), "breaker starts closed");
        let mut op = |_url: &str| -> Result<ureq::Response, SdkError> {
            Err(SdkError::Offline("simulated offline".to_string()))
        };
        let result =
            client.execute_with_retry_for_url("https://primary.example.invalid", &circuit, &mut op);
        assert!(matches!(result, Err(SdkError::Offline(_))));
        assert_eq!(
            circuit.failure_count(),
            0,
            "breaker must not count offline errors toward the failure threshold"
        );
        assert!(
            circuit.is_closed(),
            "breaker must stay closed after offline errors"
        );
    }
    // Offline errors stop the per-URL retry loop after the first attempt.
    #[test]
    fn test_offline_error_short_circuits_retry_loop() {
        use std::sync::atomic::{AtomicU32, Ordering};
        let client = RegistryClient::with_url("https://primary.example.invalid").unwrap();
        let circuit = client.circuits[0].clone();
        let call_count = AtomicU32::new(0);
        let mut op = |_url: &str| -> Result<ureq::Response, SdkError> {
            call_count.fetch_add(1, Ordering::SeqCst);
            Err(SdkError::Offline("simulated offline".to_string()))
        };
        let result =
            client.execute_with_retry_for_url("https://primary.example.invalid", &circuit, &mut op);
        assert!(result.is_err());
        assert_eq!(
            call_count.load(Ordering::SeqCst),
            1,
            "offline errors must not be retried within a single URL"
        );
    }
    // Without an override, the client reports DEFAULT_BINDING.
    #[test]
    fn registry_client_default_binding_is_rust() {
        let client = RegistryClient::default_client().unwrap();
        assert_eq!(client.binding(), DEFAULT_BINDING);
    }
    // `with_binding` overrides the reported binding.
    #[test]
    fn registry_client_with_binding_overrides_default() {
        let client = RegistryClient::default_client()
            .unwrap()
            .with_binding("flutter");
        assert_eq!(client.binding(), "flutter");
    }
    // The header is attached to requests when telemetry is enabled (request is never sent).
    #[test]
    fn apply_client_header_sets_header_when_not_opted_out() {
        let client = RegistryClient::with_url("http://127.0.0.1:1").unwrap();
        let req = client.agent.get("http://127.0.0.1:1/v1/models");
        let req = client.apply_client_header_with_optout(req, false);
        let header = req.header(CLIENT_HEADER_NAME);
        assert!(header.is_some(), "header must be set when opt-out is false");
        let value = header.unwrap();
        assert!(value.contains("binding=rust;"), "value: {}", value);
        assert!(value.contains("sdk_version="), "value: {}", value);
        assert!(value.contains("core_version="), "value: {}", value);
        assert!(value.contains("platform="), "value: {}", value);
        assert!(value.contains("backends="), "value: {}", value);
    }
    // No header is attached when telemetry is opted out.
    #[test]
    fn apply_client_header_omits_header_when_opted_out() {
        let client = RegistryClient::with_url("http://127.0.0.1:1").unwrap();
        let req = client.agent.get("http://127.0.0.1:1/v1/models");
        let req = client.apply_client_header_with_optout(req, true);
        assert_eq!(
            req.header(CLIENT_HEADER_NAME),
            None,
            "no header on the wire when telemetry is opted out"
        );
    }
    // The configured binding is what ends up in the header.
    #[test]
    fn apply_client_header_uses_configured_binding() {
        let client = RegistryClient::with_url("http://127.0.0.1:1")
            .unwrap()
            .with_binding("flutter");
        let req = client.agent.get("http://127.0.0.1:1/v1/models");
        let req = client.apply_client_header_with_optout(req, false);
        let value = req.header(CLIENT_HEADER_NAME).unwrap();
        assert!(
            value.starts_with("binding=flutter;"),
            "configured binding must flow into the header: {}",
            value
        );
    }
    // End-to-end against a mock server: list/get/resolve each send the exact header.
    // Mocks only match when the header value is identical, so a missing or different
    // header makes the call fail. NOTE(review): this relies on telemetry not being opted
    // out in the test environment, since the real `apply_client_header` is used.
    #[test]
    fn metadata_calls_send_x_xybrid_client_header() {
        use httpmock::prelude::*;
        let expected = build_client_header_with_optout("flutter", false)
            .expect("header must be built when not opted out");
        let server = MockServer::start();
        let list_mock = server.mock(|when, then| {
            when.method(GET)
                .path("/v1/models")
                .header(CLIENT_HEADER_NAME, expected.as_str());
            then.status(200)
                .header("content-type", "application/json")
                .body(r#"{"models": []}"#);
        });
        let get_mock = server.mock(|when, then| {
            when.method(GET)
                .path("/v1/models/test-model")
                .header(CLIENT_HEADER_NAME, expected.as_str());
            then.status(200).header("content-type", "application/json")
                .body(
                    r#"{"id":"test-model","family":"test","task":"text-generation","parameters":1,"description":"d","default_variant":null,"variants":{}}"#,
                );
        });
        let resolve_mock = server.mock(|when, then| {
            when.method(GET)
                .path("/v1/models/test-model/resolve")
                .query_param_exists("platform")
                .header(CLIENT_HEADER_NAME, expected.as_str());
            then.status(200)
                .header("content-type", "application/json")
                .body(
                    r#"{"mask":"test-model","platform":"x","resolved":{"hf_repo":"o/r","file":"u.xyb","download_url":"https://x","format":"onnx","quantization":"fp32","size_bytes":1,"sha256":""}}"#,
                );
        });
        let client = RegistryClient::with_url(server.base_url())
            .unwrap()
            .with_binding("flutter");
        client.list_models().expect("list_models should succeed");
        client
            .get_model("test-model")
            .expect("get_model should succeed");
        client
            .resolve("test-model", Some("apple-arm64-cpu"))
            .expect("resolve should succeed");
        list_mock.assert();
        get_mock.assert();
        resolve_mock.assert();
        assert!(expected.starts_with("binding=flutter;"), "{}", expected);
        assert!(expected.contains("sdk_version="), "{}", expected);
        assert!(expected.contains("core_version="), "{}", expected);
        assert!(expected.contains("platform="), "{}", expected);
        assert!(expected.contains("backends="), "{}", expected);
    }
    // With opt-out forced, a request reaches the permissive mock (200) and the
    // header-requiring mock (599) is never hit.
    #[test]
    fn metadata_calls_omit_header_when_opt_out_helper_returns_none() {
        use httpmock::prelude::*;
        let server = MockServer::start();
        let permissive = server.mock(|when, then| {
            when.method(GET).path("/v1/models");
            then.status(200)
                .header("content-type", "application/json")
                .body(r#"{"models": []}"#);
        });
        let with_header = server.mock(|when, then| {
            when.method(GET)
                .path("/v1/models")
                .header_exists(CLIENT_HEADER_NAME);
            then.status(599).body("UNEXPECTED HEADER");
        });
        let client = RegistryClient::with_url(server.base_url()).unwrap();
        let url = format!("{}/v1/models", server.base_url());
        let req = client.apply_client_header_with_optout(client.agent.get(&url), true);
        let response = req.call().expect("request should reach mock server");
        assert_eq!(response.status(), 200);
        permissive.assert();
        assert_eq!(
            with_header.hits(),
            0,
            "X-Xybrid-Client header must NOT be sent when opted out"
        );
    }
}