use std::collections::BTreeMap;
use std::fmt;
use serde::{Deserialize, Serialize};
/// Machine-readable category of an [`Error`].
///
/// Serialized as a snake_case string (e.g. `EmbeddingModelMismatch` is written as
/// `"embedding_model_mismatch"`), which is the same text returned by [`ErrorCode::as_str`].
/// Each code maps to an HTTP status via [`ErrorCode::http_status`]; several codes share a status.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ErrorCode {
/// Wire name `invalid_request`; HTTP 400.
InvalidRequest,
/// Wire name `multi_query_limit_exceeded`; HTTP 400.
MultiQueryLimitExceeded,
/// Wire name `unauthorized`; HTTP 401.
Unauthorized,
/// Wire name `forbidden`; HTTP 403.
Forbidden,
/// Wire name `not_found`; HTTP 404.
NotFound,
/// Wire name `embedding_model_mismatch`; HTTP 409.
EmbeddingModelMismatch,
/// Wire name `run_aborted`; HTTP 409.
RunAborted,
/// Wire name `rate_limited`; HTTP 429.
RateLimited,
/// Wire name `internal`; HTTP 500.
Internal,
/// Wire name `service_unavailable`; HTTP 503.
ServiceUnavailable,
/// Wire name `cloud_unreachable`; HTTP 503.
CloudUnreachable,
/// Wire name `models_missing`; HTTP 503.
ModelsMissing,
/// Wire name `telemetry_schema_invalid`; HTTP 500.
TelemetrySchemaInvalid,
}
impl ErrorCode {
    /// Returns the stable snake_case identifier for this code.
    ///
    /// The strings are identical to the ones serde produces for this enum via
    /// `#[serde(rename_all = "snake_case")]`, and to the text written by `Display`.
    #[must_use]
    pub const fn as_str(self) -> &'static str {
        match self {
            // Codes that map to a 4xx status.
            Self::InvalidRequest => "invalid_request",
            Self::MultiQueryLimitExceeded => "multi_query_limit_exceeded",
            Self::Unauthorized => "unauthorized",
            Self::Forbidden => "forbidden",
            Self::NotFound => "not_found",
            Self::EmbeddingModelMismatch => "embedding_model_mismatch",
            Self::RunAborted => "run_aborted",
            Self::RateLimited => "rate_limited",
            // Codes that map to a 5xx status.
            Self::Internal => "internal",
            Self::TelemetrySchemaInvalid => "telemetry_schema_invalid",
            Self::ServiceUnavailable => "service_unavailable",
            Self::CloudUnreachable => "cloud_unreachable",
            Self::ModelsMissing => "models_missing",
        }
    }

    /// Returns the HTTP status code associated with this error code.
    ///
    /// Single-code statuses are listed first; the remaining arms group every code that
    /// shares 400, 409, 500 or 503.
    #[must_use]
    pub const fn http_status(self) -> u16 {
        match self {
            Self::Unauthorized => 401,
            Self::Forbidden => 403,
            Self::NotFound => 404,
            Self::RateLimited => 429,
            Self::InvalidRequest | Self::MultiQueryLimitExceeded => 400,
            Self::EmbeddingModelMismatch | Self::RunAborted => 409,
            Self::Internal | Self::TelemetrySchemaInvalid => 500,
            Self::ServiceUnavailable | Self::CloudUnreachable | Self::ModelsMissing => 503,
        }
    }
}
impl fmt::Display for ErrorCode {
    /// Writes the snake_case wire name returned by [`ErrorCode::as_str`].
    ///
    /// Uses [`fmt::Formatter::pad`] rather than `write_str` so the caller's width, fill,
    /// alignment and precision flags are honoured (e.g. `format!("{:<28}", code)` for aligned
    /// tables). Plain `{}` output is byte-for-byte the same as before.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.pad(self.as_str())
    }
}
/// Structured key/value details attached to an [`Error`].
///
/// Backed by a `BTreeMap`, so entries iterate (and therefore serialize) in sorted key order,
/// and inserting an existing key replaces its value.
pub type ErrorContext = BTreeMap<String, serde_json::Value>;
/// A structured, serializable error: a machine-readable [`ErrorCode`], a human-readable
/// summary, a suggested next step, and optional key/value context.
///
/// Usually constructed through [`Error::builder`], which refuses to build without both a
/// message and a remediation.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
pub struct Error {
/// Machine-readable category; also determines [`Error::http_status`].
pub code: ErrorCode,
/// Human-readable summary of what went wrong; printed by the `Display` impl after the code.
pub message: String,
/// The next step the reader should take to resolve the error.
pub remediation: String,
/// Extra structured details. Omitted from serialized output when empty, and filled with an
/// empty map when the field is absent during deserialization.
#[serde(default, skip_serializing_if = "BTreeMap::is_empty")]
pub context: ErrorContext,
}
impl Error {
#[must_use]
pub const fn builder(code: ErrorCode) -> ErrorBuilder {
ErrorBuilder {
code,
message: None,
remediation: None,
context: ErrorContext::new(),
}
}
#[must_use]
pub const fn http_status(&self) -> u16 {
self.code.http_status()
}
}
impl fmt::Display for Error {
    /// Formats as `<code>: <message>`, e.g. `not_found: source not found`.
    ///
    /// The remediation and context are not included in this one-line form; they remain
    /// available on the struct itself.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self { code, message, .. } = self;
        write!(f, "{}: {}", code, message)
    }
}

/// Lets [`Error`] be used wherever a standard error is expected (e.g. boxed as
/// `dyn std::error::Error`). It wraps no underlying cause, so the default `source()`
/// (returning `None`) is kept.
impl std::error::Error for Error {}
/// Incremental constructor for [`Error`], obtained from [`Error::builder`].
///
/// `message` and `remediation` are mandatory: [`ErrorBuilder::build`] panics if either was
/// never set. Context entries are optional.
#[derive(Debug)]
pub struct ErrorBuilder {
    /// Code chosen when the builder was created.
    code: ErrorCode,
    /// Summary text; `None` until [`ErrorBuilder::message`] is called.
    message: Option<String>,
    /// Next-step text; `None` until [`ErrorBuilder::remediation`] is called.
    remediation: Option<String>,
    /// Accumulated context entries.
    context: ErrorContext,
}

impl ErrorBuilder {
    /// Sets the human-readable summary. Calling it again replaces the previous value.
    #[must_use]
    pub fn message(mut self, msg: impl Into<String>) -> Self {
        self.message = Some(msg.into());
        self
    }

    /// Sets the suggested next step. Calling it again replaces the previous value.
    #[must_use]
    pub fn remediation(mut self, rem: impl Into<String>) -> Self {
        self.remediation = Some(rem.into());
        self
    }

    /// Adds one structured context entry. Re-using a key overwrites its earlier value.
    #[must_use]
    pub fn context(mut self, key: impl Into<String>, value: impl Into<serde_json::Value>) -> Self {
        self.context.insert(key.into(), value.into());
        self
    }

    /// Consumes the builder and returns the finished [`Error`].
    ///
    /// # Panics
    ///
    /// Panics if [`message`](Self::message) or [`remediation`](Self::remediation) was never
    /// called. A missing field is a programming error at the construction site, so
    /// `#[track_caller]` makes the panic report the caller of `build` rather than this file.
    #[must_use]
    #[track_caller]
    pub fn build(self) -> Error {
        Error {
            code: self.code,
            message: self
                .message
                .expect("Error::builder requires .message() — every error must have a summary"),
            remediation: self.remediation.expect(
                "Error::builder requires .remediation() — every error must point at a next step",
            ),
            context: self.context,
        }
    }
}
/// Convenience alias for results whose error type is this module's [`Error`].
pub type Result<T> = std::result::Result<T, Error>;
#[cfg(test)]
mod tests {
    use super::*;

    /// Every `ErrorCode` variant, so the wire-format tests cover the whole enum rather than a
    /// hand-picked sample. Add new variants here when extending `ErrorCode`.
    const ALL_CODES: [ErrorCode; 13] = [
        ErrorCode::InvalidRequest,
        ErrorCode::MultiQueryLimitExceeded,
        ErrorCode::Unauthorized,
        ErrorCode::Forbidden,
        ErrorCode::NotFound,
        ErrorCode::EmbeddingModelMismatch,
        ErrorCode::RunAborted,
        ErrorCode::RateLimited,
        ErrorCode::Internal,
        ErrorCode::ServiceUnavailable,
        ErrorCode::CloudUnreachable,
        ErrorCode::ModelsMissing,
        ErrorCode::TelemetrySchemaInvalid,
    ];

    #[test]
    fn code_round_trips_through_json() {
        for code in ALL_CODES {
            let s = serde_json::to_string(&code).unwrap();
            let back: ErrorCode = serde_json::from_str(&s).unwrap();
            assert_eq!(code, back);
        }
    }

    /// `as_str` is hand-written while serde derives names from `rename_all`; pin them (and
    /// `Display`) to the same string so the two sources of truth cannot drift apart.
    #[test]
    fn as_str_matches_serde_and_display() {
        for code in ALL_CODES {
            let wire = serde_json::to_value(code).unwrap();
            assert_eq!(wire, serde_json::Value::from(code.as_str()), "{:?}", code);
            assert_eq!(code.to_string(), code.as_str());
        }
    }

    #[test]
    fn wire_shape_is_stable() {
        let err = Error::builder(ErrorCode::EmbeddingModelMismatch)
            .message("client model bge-small@1 does not match corpus model bge-base@1")
            .remediation("run `mnm models pull` to fetch bge-base@1")
            .context("corpus_model", "bge-base-en-v1.5@1")
            .context("client_model", "bge-small-en-v1.5@1")
            .build();
        let v: serde_json::Value = serde_json::to_value(&err).unwrap();
        assert_eq!(v["code"], "embedding_model_mismatch");
        assert!(v["message"].is_string());
        assert!(v["remediation"].is_string());
        assert_eq!(v["context"]["corpus_model"], "bge-base-en-v1.5@1");
        // The serialized form must deserialize back to an identical error.
        let back: Error = serde_json::from_value(v).unwrap();
        assert_eq!(back, err);
    }

    #[test]
    fn http_status_mapping() {
        assert_eq!(ErrorCode::InvalidRequest.http_status(), 400);
        assert_eq!(ErrorCode::MultiQueryLimitExceeded.http_status(), 400);
        assert_eq!(ErrorCode::Unauthorized.http_status(), 401);
        assert_eq!(ErrorCode::Forbidden.http_status(), 403);
        assert_eq!(ErrorCode::NotFound.http_status(), 404);
        assert_eq!(ErrorCode::EmbeddingModelMismatch.http_status(), 409);
        assert_eq!(ErrorCode::RunAborted.http_status(), 409);
        assert_eq!(ErrorCode::RateLimited.http_status(), 429);
        assert_eq!(ErrorCode::Internal.http_status(), 500);
        assert_eq!(ErrorCode::TelemetrySchemaInvalid.http_status(), 500);
        assert_eq!(ErrorCode::ServiceUnavailable.http_status(), 503);
        assert_eq!(ErrorCode::CloudUnreachable.http_status(), 503);
        assert_eq!(ErrorCode::ModelsMissing.http_status(), 503);
    }

    #[test]
    #[should_panic(expected = "requires .message()")]
    fn builder_panics_without_message() {
        let _ = Error::builder(ErrorCode::Internal)
            .remediation("file an issue")
            .build();
    }

    #[test]
    #[should_panic(expected = "requires .remediation()")]
    fn builder_panics_without_remediation() {
        let _ = Error::builder(ErrorCode::Internal)
            .message("something broke")
            .build();
    }

    #[test]
    fn empty_context_is_elided_from_wire() {
        let err = Error::builder(ErrorCode::NotFound)
            .message("source not found")
            .remediation("check `mnm sources list`")
            .build();
        let v: serde_json::Value = serde_json::to_value(&err).unwrap();
        assert!(v.get("context").is_none(), "empty context must be elided");
    }

    /// The inverse of elision: a payload without `context` must still deserialize.
    #[test]
    fn missing_context_deserializes_as_empty() {
        let err: Error = serde_json::from_str(
            r#"{"code":"not_found","message":"source not found","remediation":"check `mnm sources list`"}"#,
        )
        .unwrap();
        assert_eq!(err.code, ErrorCode::NotFound);
        assert!(err.context.is_empty());
        assert_eq!(err.http_status(), 404);
    }

    #[test]
    fn display_is_code_then_message() {
        let err = Error::builder(ErrorCode::RateLimited)
            .message("too many requests")
            .remediation("retry later")
            .build();
        assert_eq!(err.to_string(), "rate_limited: too many requests");
    }
}