use ollama_rs::error::OllamaError;
use reqwest::StatusCode;
use thiserror::Error;
/// Error type returned through this module's `Result` alias.
///
/// `map_ollama_error` and its helpers collapse `OllamaError` and
/// `reqwest::Error` values into these variants. `runtime_error_is_retryable`
/// treats `Network`, `Timeout` and `ServerError` as retryable.
#[derive(Debug, Error, Clone, PartialEq, Eq)]
pub enum RuntimeError {
/// Connection failure. Also the fallback for a `reqwest::Error` that is
/// neither a timeout nor a connect error and carries no HTTP status.
#[error("network error")]
Network,
/// The underlying request reported a timeout.
#[error("request timed out")]
Timeout,
/// The server answered with HTTP 401 or 403.
#[error("unauthorized")]
Unauthorized,
/// The server answered with HTTP 404.
#[error("not found")]
NotFound,
/// An error message matched the "model is missing" heuristic; the payload
/// is the message text that matched.
#[error("model not found: {0}")]
ModelNotFound(String),
/// The server answered with a 5xx status.
#[error("server error")]
ServerError,
/// Any failure not covered by the variants above; the payload is the
/// message shown to the caller.
#[error("{0}")]
Other(String),
}
/// Convenience alias for results whose error type is `RuntimeError`.
pub type Result<T> = std::result::Result<T, RuntimeError>;
/// Reports whether an operation that failed with `err` may be attempted again.
///
/// Returns `true` for `Network`, `Timeout` and `ServerError`, and `false` for
/// every other variant.
pub(crate) fn runtime_error_is_retryable(err: &RuntimeError) -> bool {
    // Spelled out variant by variant so that adding a new variant forces an
    // explicit retry decision here.
    match err {
        RuntimeError::Network | RuntimeError::Timeout | RuntimeError::ServerError => true,
        RuntimeError::Unauthorized
        | RuntimeError::NotFound
        | RuntimeError::ModelNotFound(_)
        | RuntimeError::Other(_) => false,
    }
}
/// Heuristically decides whether an error message says a model is missing.
///
/// The comparison is ASCII-case-insensitive. The message must mention
/// `model` and also contain at least one of the phrases `not found`,
/// `unknown model`, `does not exist` or `pull`.
fn looks_like_model_missing_message(msg: &str) -> bool {
    // Phrases that, together with the word "model", indicate a missing model.
    const MISSING_HINTS: [&str; 4] = ["not found", "unknown model", "does not exist", "pull"];

    let lowered = msg.to_ascii_lowercase();
    if !lowered.contains("model") {
        return false;
    }
    MISSING_HINTS.iter().any(|hint| lowered.contains(hint))
}
/// Reports whether a raw `OllamaError` is worth retrying.
///
/// Only the `ReqwestError` variant can be retryable, and the decision for it
/// is delegated to `reqwest_error_is_retryable`. All other variants return
/// `false`.
pub(crate) fn ollama_error_is_retryable(err: &OllamaError) -> bool {
    match err {
        // Non-transport variants are never retried.
        OllamaError::Other(_)
        | OllamaError::ToolCallError(_)
        | OllamaError::InternalError(_)
        | OllamaError::JsonError(_) => false,
        OllamaError::ReqwestError(transport) => reqwest_error_is_retryable(transport),
    }
}
/// Reports whether a `reqwest::Error` is worth retrying.
///
/// Returns `true` for timeouts, for connect errors, and for errors carrying a
/// 5xx HTTP status. Everything else, including errors with no status at all,
/// returns `false`.
fn reqwest_error_is_retryable(err: &reqwest::Error) -> bool {
    let transport_failure = err.is_timeout() || err.is_connect();
    transport_failure || matches!(err.status(), Some(code) if code.is_server_error())
}
/// Converts an `OllamaError` into this crate's `RuntimeError`.
///
/// * `ReqwestError` is classified by `map_reqwest_error`.
/// * `InternalError` becomes `ModelNotFound` when its message matches the
///   model-missing heuristic, otherwise `Other`; the message is carried along
///   in both cases.
/// * `Other` is classified by `map_ollama_other_string`.
/// * `JsonError` and `ToolCallError` become `Other` holding the error's
///   `to_string()` output.
pub(crate) fn map_ollama_error(err: OllamaError) -> RuntimeError {
    match err {
        OllamaError::JsonError(json_err) => RuntimeError::Other(json_err.to_string()),
        OllamaError::ToolCallError(tool_err) => RuntimeError::Other(tool_err.to_string()),
        OllamaError::Other(text) => map_ollama_other_string(text),
        OllamaError::ReqwestError(transport) => map_reqwest_error(transport),
        OllamaError::InternalError(internal) => {
            let message = internal.message;
            if looks_like_model_missing_message(&message) {
                RuntimeError::ModelNotFound(message)
            } else {
                RuntimeError::Other(message)
            }
        }
    }
}
/// Classifies the free-form text carried by `OllamaError::Other`.
///
/// When `s` parses as JSON and has a string-valued `error` field, that field
/// is used as the message; otherwise `s` itself is the message. The message
/// is then returned as `RuntimeError::ModelNotFound` if it matches the
/// model-missing heuristic, or as `RuntimeError::Other` if it does not.
///
/// The JSON field is extracted *before* the heuristic runs. Checking the raw
/// text first would return the whole JSON envelope as the `ModelNotFound`
/// payload and would let unrelated keys (for example a `"model"` key) affect
/// the classification.
fn map_ollama_other_string(s: String) -> RuntimeError {
    // Pull out `{"error": "<text>"}` if the payload has that shape.
    let extracted = serde_json::from_str::<serde_json::Value>(&s)
        .ok()
        .and_then(|body| {
            body.get("error")
                .and_then(|e| e.as_str())
                .map(str::to_owned)
        });

    // Fall back to the raw text for non-JSON payloads or JSON without a
    // string `error` field.
    let message = extracted.unwrap_or(s);

    if looks_like_model_missing_message(&message) {
        RuntimeError::ModelNotFound(message)
    } else {
        RuntimeError::Other(message)
    }
}
fn map_reqwest_error(err: reqwest::Error) -> RuntimeError {
if err.is_timeout() {
return RuntimeError::Timeout;
}
if err.is_connect() {
return RuntimeError::Network;
}
if let Some(status) = err.status() {
return map_http_status(status);
}
RuntimeError::Network
}
/// Converts an HTTP status code into a `RuntimeError`.
///
/// Any 5xx status becomes `ServerError`, 401 and 403 become `Unauthorized`,
/// and 404 becomes `NotFound`. Every other status becomes `Other` holding the
/// status's `to_string()` output.
fn map_http_status(status: StatusCode) -> RuntimeError {
    match status {
        code if code.is_server_error() => RuntimeError::ServerError,
        StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => RuntimeError::Unauthorized,
        StatusCode::NOT_FOUND => RuntimeError::NotFound,
        other => RuntimeError::Other(other.to_string()),
    }
}