use pyo3::prelude::*;
use pyo3::PyErr;
use oxillama_arch::ArchError;
use oxillama_runtime::RuntimeError;
// Python exception hierarchy exported by this module.
//
// `OxiLlamaError` is declared with Python's built-in `Exception`
// (`PyException`) as its base; every other class below is declared with
// `OxiLlamaError` as its base, so Python callers can catch the root class to
// handle any of them.
pyo3::create_exception!(oxillama_py, OxiLlamaError, pyo3::exceptions::PyException);
// Used by the mapping functions in this file for model loading, GGUF, I/O,
// tensor lookup and configuration failures.
pyo3::create_exception!(oxillama_py, LoadError, OxiLlamaError);
// Used for failures during generation (sampling, attention, forward pass,
// cancellation, snapshots, unsupported operations).
pyo3::create_exception!(oxillama_py, GenerateError, OxiLlamaError);
// Used when the tokenizer is unavailable or reports an error.
pyo3::create_exception!(oxillama_py, TokenizerError, OxiLlamaError);
// Used for grammar errors reported by the runtime.
pyo3::create_exception!(oxillama_py, GrammarError, OxiLlamaError);
// Used for quantization errors from the runtime or architecture layer.
pyo3::create_exception!(oxillama_py, QuantError, OxiLlamaError);
// Used when the KV cache has reached its maximum context length.
pyo3::create_exception!(oxillama_py, KvCacheFullError, OxiLlamaError);
/// Adds every OxiLlama exception class to the Python module `m`.
///
/// Each class is registered under the same name as its Rust type, in
/// base-class-first order.
///
/// # Errors
///
/// Returns the first error produced by `m.add`; classes after the failing one
/// are not registered.
pub fn register_exceptions(m: &Bound<'_, PyModule>) -> PyResult<()> {
    // The GIL token is `Copy`, so fetch it once and reuse it for every lookup.
    let py = m.py();

    // Root of the hierarchy.
    m.add("OxiLlamaError", py.get_type::<OxiLlamaError>())?;

    // Concrete subclasses.
    m.add("LoadError", py.get_type::<LoadError>())?;
    m.add("GenerateError", py.get_type::<GenerateError>())?;
    m.add("TokenizerError", py.get_type::<TokenizerError>())?;
    m.add("GrammarError", py.get_type::<GrammarError>())?;
    m.add("QuantError", py.get_type::<QuantError>())?;
    m.add("KvCacheFullError", py.get_type::<KvCacheFullError>())?;

    Ok(())
}
pub fn runtime_to_py(err: RuntimeError) -> PyErr {
match err {
RuntimeError::ModelNotLoaded => {
GenerateError::new_err("Model not loaded — call load_model() first")
}
RuntimeError::TokenizerNotAvailable => TokenizerError::new_err(
"Tokenizer not available — rebuild with the `tokenizer-wasm` feature enabled",
),
RuntimeError::ModelLoadError { message } => {
LoadError::new_err(format!("Model load error: {message}"))
}
RuntimeError::TokenizerError { message } => {
TokenizerError::new_err(format!("Tokenizer error: {message}"))
}
RuntimeError::SamplingError { message } => {
GenerateError::new_err(format!("Sampling error: {message}"))
}
RuntimeError::KvCacheFull { max_ctx } => KvCacheFullError::new_err(format!(
"KV cache full: maximum context length {max_ctx} reached"
)),
RuntimeError::Cancelled => GenerateError::new_err("Generation cancelled"),
RuntimeError::Arch(arch_err) => arch_to_py(arch_err),
RuntimeError::Gguf(gguf_err) => LoadError::new_err(format!("GGUF parse error: {gguf_err}")),
RuntimeError::Quant(quant_err) => {
QuantError::new_err(format!("Quantization error: {quant_err}"))
}
RuntimeError::Io(io_err) => LoadError::new_err(format!("I/O error: {io_err}")),
RuntimeError::Grammar(grammar_err) => {
GrammarError::new_err(format!("Grammar error: {grammar_err}"))
}
RuntimeError::AttentionError { message } => {
GenerateError::new_err(format!("Attention error: {message}"))
}
RuntimeError::SnapshotIncompatible { detail } => {
GenerateError::new_err(format!("Snapshot incompatible: {detail}"))
}
RuntimeError::ModelFingerprintMismatch {
expected,
found,
detail,
} => LoadError::new_err(format!(
"Model fingerprint mismatch — expected {expected}, found {found}: {detail}"
)),
RuntimeError::OffloadEof {
offset,
needed,
available,
} => LoadError::new_err(format!(
"Offload I/O error: unexpected EOF at offset {offset}, needed {needed} bytes, {available} available"
)),
RuntimeError::TensorNotFound(name) => {
LoadError::new_err(format!("Tensor not found in weight map: '{name}'"))
}
RuntimeError::LockPoisoned => {
GenerateError::new_err("Internal error: lock poisoned")
}
RuntimeError::SpecSnapshotIncompatible(detail) => {
GenerateError::new_err(format!("Speculative snapshot incompatible: {detail}"))
}
RuntimeError::EmptySequence => {
GenerateError::new_err("Input tokenizes to an empty sequence — provide at least one token")
}
}
}
pub fn arch_to_py(err: ArchError) -> PyErr {
match err {
ArchError::Gguf(gguf_err) => {
LoadError::new_err(format!("GGUF error in arch layer: {gguf_err}"))
}
ArchError::Quant(quant_err) => {
QuantError::new_err(format!("Quantization error in arch layer: {quant_err}"))
}
ArchError::MissingTensor { name } => {
LoadError::new_err(format!("Missing tensor in model: '{name}'"))
}
ArchError::UnknownArchitecture { arch_id } => {
LoadError::new_err(format!("Unknown architecture: '{arch_id}'"))
}
ArchError::ConfigMismatch {
param,
expected,
got,
} => LoadError::new_err(format!(
"Config mismatch for '{param}': expected {expected}, got {got}"
)),
ArchError::TensorShapeMismatch {
tensor,
expected,
got,
} => LoadError::new_err(format!(
"Tensor shape mismatch for '{tensor}': expected {expected:?}, got {got:?}"
)),
ArchError::NotSupported { detail } => {
GenerateError::new_err(format!("Operation not supported: {detail}"))
}
ArchError::ForwardPassError { layer, message } => {
GenerateError::new_err(format!("Forward pass error at layer {layer}: {message}"))
}
ArchError::InvalidShape {
name,
expected,
got,
} => LoadError::new_err(format!(
"Invalid shape for '{name}': expected {expected:?}, got {got:?}"
)),
ArchError::InvalidConfig { detail } => {
LoadError::new_err(format!("Invalid configuration: {detail}"))
}
ArchError::UnsupportedOperation { message } => {
GenerateError::new_err(format!("Unsupported operation: {message}"))
}
ArchError::LoraIncompatible { detail } => {
LoadError::new_err(format!("LoRA adapter incompatible: {detail}"))
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Feeds `RuntimeError` variants through `runtime_to_py` and checks that
    /// the conversion does not panic.
    #[test]
    fn test_all_runtime_error_variants_map_without_panic() {
        // Shadows the Python `GrammarError` exception pulled in by `super::*`;
        // inside this test the name refers to the runtime's grammar error.
        use oxillama_runtime::sampling::grammar::GrammarError;

        // NOTE(review): `AttentionError`, `SnapshotIncompatible` and
        // `ModelFingerprintMismatch` are still not constructed here because
        // their field types are not visible from this file — add them once
        // confirmed against `oxillama_runtime::RuntimeError`.
        let variants: Vec<RuntimeError> = vec![
            RuntimeError::ModelNotLoaded,
            RuntimeError::TokenizerNotAvailable,
            RuntimeError::TokenizerError {
                message: "test".to_string(),
            },
            RuntimeError::SamplingError {
                message: "test".to_string(),
            },
            RuntimeError::KvCacheFull { max_ctx: 1024 },
            RuntimeError::ModelLoadError {
                message: "test".to_string(),
            },
            RuntimeError::Cancelled,
            // Exercises the delegation from `runtime_to_py` into `arch_to_py`.
            RuntimeError::Arch(ArchError::MissingTensor {
                name: "blk.0.attn_q.weight".to_string(),
            }),
            RuntimeError::Gguf(oxillama_gguf::GgufError::InvalidMagic { magic: 0 }),
            RuntimeError::Quant(oxillama_quant::QuantError::UnsupportedType {
                quant_type: "Q99".to_string(),
            }),
            RuntimeError::Io(std::io::Error::new(std::io::ErrorKind::NotFound, "test")),
            RuntimeError::Grammar(GrammarError::Stuck),
            RuntimeError::OffloadEof {
                offset: 0,
                needed: 4,
                available: 2,
            },
            RuntimeError::TensorNotFound("blk.0.attn_q.weight".to_string()),
            RuntimeError::LockPoisoned,
            RuntimeError::SpecSnapshotIncompatible("test detail".to_string()),
            RuntimeError::EmptySequence,
        ];
        for variant in variants {
            // Only the absence of a panic is checked; the resulting `PyErr`
            // is discarded.
            let _ = runtime_to_py(variant);
        }
    }

    /// Feeds `ArchError` variants through `arch_to_py` and checks that the
    /// conversion does not panic.
    #[test]
    fn test_all_arch_error_variants_map_without_panic() {
        // NOTE(review): `ArchError::UnsupportedOperation` is not constructed
        // here because its field type is not visible from this file — add it
        // once confirmed against `oxillama_arch::ArchError`.
        let variants: Vec<ArchError> = vec![
            ArchError::Gguf(oxillama_gguf::GgufError::InvalidMagic { magic: 0 }),
            ArchError::Quant(oxillama_quant::QuantError::UnsupportedType {
                quant_type: "Q99".to_string(),
            }),
            ArchError::MissingTensor {
                name: "blk.0.attn_q.weight".to_string(),
            },
            ArchError::UnknownArchitecture {
                arch_id: "foo".to_string(),
            },
            ArchError::ConfigMismatch {
                param: "p".to_string(),
                expected: "a".to_string(),
                got: "b".to_string(),
            },
            ArchError::TensorShapeMismatch {
                tensor: "t".to_string(),
                expected: vec![2, 3],
                got: vec![3, 2],
            },
            ArchError::NotSupported {
                detail: "x".to_string(),
            },
            ArchError::ForwardPassError {
                layer: 0,
                message: "err".to_string(),
            },
            ArchError::InvalidShape {
                name: "expert.gate".to_string(),
                expected: vec![8, 4],
                got: vec![3],
            },
            ArchError::InvalidConfig {
                detail: "top_k must be >= 1".to_string(),
            },
            ArchError::LoraIncompatible {
                detail: "rank mismatch".to_string(),
            },
        ];
        for variant in variants {
            // Only the absence of a panic is checked; the resulting `PyErr`
            // is discarded.
            let _ = arch_to_py(variant);
        }
    }

    /// The Rust-side `Display` text for `ModelNotLoaded` mentions the model.
    /// This checks the Rust message, not the Python exception text.
    #[test]
    fn test_model_not_loaded_message() {
        let rust_msg = RuntimeError::ModelNotLoaded.to_string();
        assert!(
            rust_msg.to_lowercase().contains("model"),
            "Rust error message should mention 'model', got: {rust_msg}"
        );
    }

    /// The Rust-side `Display` text for `ModelLoadError` keeps the original
    /// cause string.
    #[test]
    fn test_model_load_error_contains_message() {
        let rust_msg = RuntimeError::ModelLoadError {
            message: "missing_xyz.gguf".to_string(),
        }
        .to_string();
        assert!(
            rust_msg.contains("missing_xyz.gguf"),
            "Rust error message should contain original cause, got: {rust_msg}"
        );
    }

    /// The Rust-side `Display` text for `KvCacheFull` includes the context
    /// limit.
    #[test]
    fn test_kv_cache_full_message() {
        let rust_msg = RuntimeError::KvCacheFull { max_ctx: 9999 }.to_string();
        assert!(
            rust_msg.contains("9999"),
            "Rust error message should contain max_ctx=9999, got: {rust_msg}"
        );
    }
}