#![cfg(feature = "gemma-default")]
use std::path::PathBuf;
use atomr_core::actor::ActorRef;
use atomr_infer_core::deployment::Deployment;
use atomr_infer_core::error::{InferenceError, InferenceResult};
use atomr_infer_core::runtime::RuntimeConfig;
use atomr_infer_runtime::{DeploymentManagerActor, DeploymentManagerMsg};
use crate::probe::{probe, ProbeResult};
use crate::VllmConfig;
/// Deployment name used by `GemmaDefaults::default()` (overridable via
/// `ATOMR_INFER_GEMMA_DEPLOYMENT` in `GemmaDefaults::from_env`).
pub const DEFAULT_DEPLOYMENT_NAME: &str = "gemma-local";
/// Model id used by `GemmaDefaults::default()` (overridable via `ATOMR_INFER_GEMMA_MODEL`
/// in `GemmaDefaults::from_env`).
pub const DEFAULT_MODEL_ID: &str = "google/gemma-4-E4B-it";
/// Every Gemma model id accepted by `validate_variant`; `provision_if_ready` rejects any
/// other id before probing the host.
///
/// Keep in sync with the per-variant tables in `min_vram_gb`, `min_disk_gb` and
/// `fallback_variant`.
pub const SUPPORTED_VARIANTS: &[&str] = &[
"google/gemma-4-E2B",
"google/gemma-4-E2B-it",
"google/gemma-4-E4B",
"google/gemma-4-E4B-it",
];
/// Minimum free GPU memory, in GB, needed to serve `model_id`.
///
/// A base checkpoint and its `-it` counterpart of the same size share one floor, so the
/// `-it` suffix is stripped (once) before the lookup. Returns `None` for any id that is not
/// a known Gemma variant.
pub fn min_vram_gb(model_id: &str) -> Option<f32> {
    let base_id = model_id.strip_suffix("-it").unwrap_or(model_id);
    match base_id {
        "google/gemma-4-E2B" => Some(2.5),
        "google/gemma-4-E4B" => Some(4.5),
        _ => None,
    }
}
/// Minimum free disk space, in GB, needed to hold `model_id`'s weights.
///
/// Returns `None` for any id missing from the table below.
pub fn min_disk_gb(model_id: &str) -> Option<f32> {
    const DISK_FLOORS_GB: [(&str, f32); 4] = [
        ("google/gemma-4-E2B", 4.0),
        ("google/gemma-4-E2B-it", 4.0),
        ("google/gemma-4-E4B", 7.0),
        ("google/gemma-4-E4B-it", 7.0),
    ];
    DISK_FLOORS_GB
        .iter()
        .find(|(id, _)| *id == model_id)
        .map(|&(_, gb)| gb)
}
/// Smaller Gemma variant to pair with `model_id`, if one exists.
///
/// Only the E4B ids have a fallback: each maps to the E2B id with the same suffix. The
/// result is passed to `probe` by `provision_if_ready`; how `probe` uses it is not visible
/// here.
pub fn fallback_variant(model_id: &str) -> Option<&'static str> {
    if model_id == "google/gemma-4-E4B-it" {
        Some("google/gemma-4-E2B-it")
    } else if model_id == "google/gemma-4-E4B" {
        Some("google/gemma-4-E2B")
    } else {
        None
    }
}
/// Checks that `model_id` is one of [`SUPPORTED_VARIANTS`] (exact, case-sensitive match).
///
/// # Errors
/// Returns `InferenceError::BadRequest` naming the rejected id and listing every supported
/// variant.
pub fn validate_variant(model_id: &str) -> InferenceResult<()> {
    let is_supported = SUPPORTED_VARIANTS.iter().any(|&known| known == model_id);
    if is_supported {
        return Ok(());
    }
    let supported = SUPPORTED_VARIANTS.join(", ");
    Err(InferenceError::BadRequest {
        message: format!("unsupported Gemma variant `{model_id}` — supported: {supported}"),
    })
}
/// Settings used to provision the default local Gemma deployment.
///
/// Build with `GemmaDefaults::default()` or `GemmaDefaults::from_env()`, then pass to
/// `provision_if_ready`.
#[derive(Debug, Clone)]
pub struct GemmaDefaults {
/// Model id to deploy; `provision_if_ready` requires it to be in `SUPPORTED_VARIANTS`.
pub model_id: String,
/// Name of the deployment sent to the deployment manager.
pub deployment_name: String,
/// Forwarded as `VllmConfig::hf_cache_dir`; `None` by default and not set by `from_env`.
pub cache_dir: Option<PathBuf>,
/// Forwarded as `VllmConfig::gpu_memory_utilization`; `0.5` by default.
pub gpu_memory_utilization: f32,
/// Forwarded as `VllmConfig::max_model_len`; `None` by default.
pub max_model_len: Option<u32>,
/// `true` by default; cleared by opt-out values of `ATOMR_INFER_GEMMA_AUTO` in `from_env`.
/// NOTE(review): `provision_if_ready` does not read this flag — presumably the caller checks
/// it before provisioning; confirm.
pub auto_provision: bool,
}
impl Default for GemmaDefaults {
fn default() -> Self {
Self {
model_id: DEFAULT_MODEL_ID.into(),
deployment_name: DEFAULT_DEPLOYMENT_NAME.into(),
cache_dir: None,
gpu_memory_utilization: 0.5,
max_model_len: None,
auto_provision: true,
}
}
}
impl GemmaDefaults {
pub fn from_env() -> Self {
let mut cfg = Self::default();
if let Some(v) = env_string("ATOMR_INFER_GEMMA_AUTO") {
cfg.auto_provision = !matches!(
v.to_ascii_lowercase().as_str(),
"0" | "false" | "no" | "off" | "skip" | "skip-quietly"
);
}
if let Some(v) = env_string("ATOMR_INFER_GEMMA_MODEL") {
cfg.model_id = v;
}
if let Some(v) = env_string("ATOMR_INFER_GEMMA_DEPLOYMENT") {
cfg.deployment_name = v;
}
if let Some(v) = env_string("ATOMR_INFER_GEMMA_GPU_UTIL") {
if let Ok(f) = v.parse::<f32>() {
cfg.gpu_memory_utilization = f;
}
}
if let Some(v) = env_string("ATOMR_INFER_GEMMA_MAX_LEN") {
if let Ok(u) = v.parse::<u32>() {
cfg.max_model_len = Some(u);
}
}
cfg
}
}
/// Non-error result of `provision_if_ready`.
#[derive(Debug)]
pub enum ProvisionOutcome {
/// The deployment manager acknowledged the `Apply` for `deployment_name`.
Ready { deployment_name: String },
/// The host probe returned `ProbeResult::Skipped`; `reason` and `hint` are passed through
/// unchanged and nothing was sent to the deployment manager.
/// NOTE(review): `hint` is presumably a human-readable remediation suggestion — confirm
/// against `crate::probe`.
Skipped { reason: String, hint: String },
}
pub async fn provision_if_ready(
manager: &ActorRef<DeploymentManagerMsg>,
cfg: &GemmaDefaults,
) -> InferenceResult<ProvisionOutcome> {
validate_variant(&cfg.model_id)?;
let min_vram = min_vram_gb(&cfg.model_id).unwrap_or(4.5);
let min_disk = min_disk_gb(&cfg.model_id).unwrap_or(7.0);
let fallback = fallback_variant(&cfg.model_id);
match probe(&cfg.model_id, min_vram, min_disk, fallback) {
ProbeResult::Skipped { reason, hint } => {
return Ok(ProvisionOutcome::Skipped { reason, hint })
}
ProbeResult::Error(e) => return Err(e),
ProbeResult::Ready { vram_free_gb, hf_cache } => {
tracing::info!(
model = %cfg.model_id,
deployment = %cfg.deployment_name,
vram_free_gb,
hf_cache = %hf_cache.hub_cache.display(),
"probe ok — provisioning Gemma deployment"
);
}
}
let vllm_cfg = VllmConfig {
model: cfg.model_id.clone(),
tensor_parallel_size: 1,
dtype: "auto".into(),
gpu_memory_utilization: Some(cfg.gpu_memory_utilization),
max_model_len: cfg.max_model_len,
hf_cache_dir: cfg.cache_dir.clone(),
enforce_eager: Some(true),
enable_prefix_caching: None,
enable_chunked_prefill: None,
max_num_seqs: Some(16),
block_size: None,
quantization: None,
limit_mm_per_prompt: None,
cpu_offload_gb: Some(4),
};
let runtime_config = serde_json::to_value(&vllm_cfg)
.map(RuntimeConfig::Vllm)
.map_err(|e| {
InferenceError::Internal(format!("gemma defaults: serialise VllmConfig: {e}"))
})?;
let deployment = Deployment {
name: cfg.deployment_name.clone(),
model: cfg.model_id.clone(),
runtime: Some(atomr_infer_core::runtime::RuntimeKind::Vllm),
runtime_config: Some(runtime_config),
gpus: Some(1),
replicas: 1,
serving: Default::default(),
budget: None,
idempotent: true,
};
let (tx, rx) = tokio::sync::oneshot::channel();
manager.tell(DeploymentManagerMsg::Apply {
deployment,
reply: tx,
});
match rx.await {
Ok(Ok(())) => Ok(ProvisionOutcome::Ready {
deployment_name: cfg.deployment_name.clone(),
}),
Ok(Err(e)) => Err(e),
Err(_) => Err(InferenceError::Internal(
"gemma defaults: deployment manager dropped reply channel".into(),
)),
}
}
/// Reads environment variable `var`, trimmed of surrounding whitespace.
///
/// Returns `None` when the variable is unset, is not valid Unicode, or is blank after
/// trimming.
fn env_string(var: &str) -> Option<String> {
    let raw = std::env::var(var).ok()?;
    let trimmed = raw.trim();
    if trimmed.is_empty() {
        None
    } else {
        Some(trimmed.to_owned())
    }
}
// `DeploymentManagerActor` is referenced nowhere else in this file; this alias keeps its
// `use` from being reported as an unused import, and `#[allow(dead_code)]` covers the alias
// itself being unused.
// NOTE(review): presumably kept to tie this module to the actor behind the
// `ActorRef<DeploymentManagerMsg>` that `provision_if_ready` takes — confirm, or drop both
// the alias and the import.
#[allow(dead_code)]
type _ManagerType = DeploymentManagerActor;
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn defaults_pick_e4b_it() {
        let d = GemmaDefaults::default();
        assert_eq!(d.model_id, "google/gemma-4-E4B-it");
        assert_eq!(d.deployment_name, "gemma-local");
        assert_eq!(d.gpu_memory_utilization, 0.5);
        assert!(d.auto_provision);
    }

    #[test]
    fn validate_variant_accepts_all_four() {
        for v in SUPPORTED_VARIANTS {
            assert!(validate_variant(v).is_ok(), "{v} should be supported");
        }
    }

    #[test]
    fn validate_variant_rejects_unknown() {
        assert!(matches!(
            validate_variant("google/some-other-model"),
            Err(InferenceError::BadRequest { .. })
        ));
    }

    #[test]
    fn fallback_e4b_to_e2b() {
        assert_eq!(
            fallback_variant("google/gemma-4-E4B-it"),
            Some("google/gemma-4-E2B-it")
        );
        assert_eq!(
            fallback_variant("google/gemma-4-E4B"),
            Some("google/gemma-4-E2B")
        );
        assert_eq!(fallback_variant("google/gemma-4-E2B-it"), None);
    }

    /// Every supported id must have both resource floors. Otherwise `provision_if_ready`
    /// silently falls back to the hard-coded E4B defaults.
    #[test]
    fn every_supported_variant_has_resource_floors() {
        for v in SUPPORTED_VARIANTS {
            assert!(min_vram_gb(v).is_some(), "{v} is missing a VRAM floor");
            assert!(min_disk_gb(v).is_some(), "{v} is missing a disk floor");
        }
    }

    /// A fallback must itself be a supported id.
    #[test]
    fn fallbacks_are_supported_variants() {
        for v in SUPPORTED_VARIANTS {
            if let Some(fb) = fallback_variant(v) {
                assert!(
                    SUPPORTED_VARIANTS.contains(&fb),
                    "fallback {fb} for {v} is not a supported variant"
                );
            }
        }
    }

    /// Serialises tests that mutate process-wide environment variables. A poisoned lock
    /// (from a panicking test) is recovered rather than failing every later test.
    fn env_lock() -> std::sync::MutexGuard<'static, ()> {
        static LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());
        LOCK.lock().unwrap_or_else(|p| p.into_inner())
    }

    /// Sets an environment variable for the guard's lifetime. On drop it restores the
    /// previous value, or removes the variable if there was none. This also runs when an
    /// assertion panics, so a failing test cannot leak its override into later tests.
    /// Declare it after the `env_lock()` guard so it is dropped while the lock is held.
    struct EnvVarGuard {
        key: &'static str,
        previous: Option<std::ffi::OsString>,
    }

    impl EnvVarGuard {
        fn set(key: &'static str, value: &str) -> Self {
            let previous = std::env::var_os(key);
            std::env::set_var(key, value);
            Self { key, previous }
        }
    }

    impl Drop for EnvVarGuard {
        fn drop(&mut self) {
            match self.previous.take() {
                Some(old) => std::env::set_var(self.key, old),
                None => std::env::remove_var(self.key),
            }
        }
    }

    #[test]
    fn from_env_respects_skip_quietly() {
        let _lock = env_lock();
        let _auto = EnvVarGuard::set("ATOMR_INFER_GEMMA_AUTO", "skip-quietly");
        assert!(!GemmaDefaults::from_env().auto_provision);
    }

    #[test]
    fn from_env_auto_opt_out_is_case_insensitive() {
        let _lock = env_lock();
        let _auto = EnvVarGuard::set("ATOMR_INFER_GEMMA_AUTO", "  OFF ");
        assert!(!GemmaDefaults::from_env().auto_provision);
    }

    #[test]
    fn from_env_overrides_model_id() {
        let _lock = env_lock();
        let _model = EnvVarGuard::set("ATOMR_INFER_GEMMA_MODEL", "google/gemma-4-E2B-it");
        assert_eq!(GemmaDefaults::from_env().model_id, "google/gemma-4-E2B-it");
    }

    #[test]
    fn vram_floors_match_table() {
        assert_eq!(min_vram_gb("google/gemma-4-E2B"), Some(2.5));
        assert_eq!(min_vram_gb("google/gemma-4-E2B-it"), Some(2.5));
        assert_eq!(min_vram_gb("google/gemma-4-E4B"), Some(4.5));
        assert_eq!(min_vram_gb("google/gemma-4-E4B-it"), Some(4.5));
        assert_eq!(min_vram_gb("unknown"), None);
    }

    #[test]
    fn disk_floors_match_table() {
        assert_eq!(min_disk_gb("google/gemma-4-E2B"), Some(4.0));
        assert_eq!(min_disk_gb("google/gemma-4-E2B-it"), Some(4.0));
        assert_eq!(min_disk_gb("google/gemma-4-E4B"), Some(7.0));
        assert_eq!(min_disk_gb("google/gemma-4-E4B-it"), Some(7.0));
        assert_eq!(min_disk_gb("unknown"), None);
    }
}