use thiserror::Error;
/// Convenience alias for fallible results produced by the natural-language pipeline,
/// with [`NlError`] as the error type.
pub type NlResult<T> = std::result::Result<T, NlError>;
#[derive(Error, Debug)]
pub enum NlError {
#[error("Preprocessing failed: {0}")]
Preprocess(#[from] PreprocessError),
#[error("Entity extraction failed: {0}")]
Extractor(#[from] ExtractorError),
#[error("Classification failed: {0}")]
Classifier(ClassifierError),
#[error("Assembly failed: {0}")]
Assembler(#[from] AssemblerError),
#[error("Validation failed: {0}")]
Validator(#[from] ValidatorError),
#[error("Cache error: {0}")]
Cache(#[from] CacheError),
#[error("Configuration error: {0}")]
Config(String),
#[error("I/O error: {0}")]
Io(#[from] std::io::Error),
#[error("Model directory not found: {0}")]
ModelDirNotFound(String),
#[error("Model file checksum mismatch for {file}: expected {expected}, got {actual}")]
ChecksumMismatch {
file: String,
expected: String,
actual: String,
},
#[error("Model checksums file is missing from model directory")]
ChecksumsMissing,
#[error("File listed in checksums is missing from model directory: {0}")]
ChecksummedFileMissing(String),
#[error("ONNX Runtime is not available: {hint}")]
OnnxRuntimeMissing {
hint: String,
},
#[error("Model manifest SHA-256 mismatch for {file}: expected {expected}, got {actual}")]
ManifestSha256Mismatch {
file: String,
expected: String,
actual: String,
},
#[error("Model manifest parse failed: {0}")]
ManifestParseFailed(#[from] serde_json::Error),
#[error("Model download is disabled by configuration")]
DownloadDisabled,
#[error("Model download failed: {0}")]
DownloadFailed(String),
}
impl From<ClassifierError> for NlError {
fn from(err: ClassifierError) -> Self {
match err {
ClassifierError::OnnxRuntimeMissing { hint } => NlError::OnnxRuntimeMissing { hint },
other => NlError::Classifier(other),
}
}
}
#[derive(Error, Debug, Clone, PartialEq, Eq)]
pub enum PreprocessError {
#[error("Input too long: {len} bytes (max: {max})")]
InputTooLong { len: usize, max: usize },
#[error("Input is empty or contains only whitespace")]
EmptyInput,
#[error("Suspicious character detected: possible homoglyph attack")]
HomoglyphDetected,
#[error("Invalid UTF-8 encoding")]
InvalidUtf8,
}
#[derive(Error, Debug, Clone, PartialEq, Eq)]
pub enum ExtractorError {
#[error("No symbol or pattern found in query")]
NoSymbolFound,
#[error("Ambiguous symbol reference: multiple interpretations possible")]
AmbiguousSymbol,
#[error("Unknown language: {0}")]
UnknownLanguage(String),
#[error("Unknown symbol kind: {0}")]
UnknownKind(String),
#[error("Pattern compilation failed: {0}")]
RegexError(String),
}
#[derive(Error, Debug)]
pub enum ClassifierError {
#[error("Model not found at: {0}")]
ModelNotFound(String),
#[error("Model checksum mismatch for {file}: expected {expected}, got {actual}")]
ChecksumMismatch {
file: String,
expected: String,
actual: String,
},
#[error("Model checksums file is missing from model directory")]
ChecksumsMissing,
#[error("File listed in checksums is missing from model directory: {0}")]
ChecksummedFileMissing(String),
#[error("Tokenization failed: {0}")]
TokenizationFailed(String),
#[error("ONNX Runtime error: {0}")]
OnnxError(String),
#[error("Model manifest integrity anchor invalid: {0}")]
ManifestAnchorInvalid(String),
#[error("ONNX Runtime is not available: {hint}")]
OnnxRuntimeMissing {
hint: String,
},
#[error("Model version {model_version} incompatible with sqry-nl {crate_version}")]
VersionMismatch {
model_version: String,
crate_version: String,
},
#[error("Classification timed out after {timeout_ms}ms")]
Timeout { timeout_ms: u64 },
}
#[derive(Error, Debug, Clone, PartialEq, Eq)]
pub enum AssemblerError {
#[error("Missing required symbol for this command type")]
MissingSymbol,
#[error("Trace-path requires both 'from' and 'to' symbols")]
MissingTracePath,
#[error("Cannot assemble command: intent is ambiguous")]
AmbiguousIntent,
#[error("Generated command too long: {len} chars (max: {max})")]
CommandTooLong { len: usize, max: usize },
#[error("No template found for intent: {0}")]
NoTemplate(String),
}
#[derive(Error, Debug, Clone, PartialEq, Eq)]
pub enum ValidatorError {
#[error("Command rejected: doesn't match any allowed template")]
TemplateMismatch,
#[error("Command rejected: contains shell metacharacters")]
MetacharDetected,
#[error("Command rejected: contains environment variable")]
EnvVarDetected,
#[error("Command rejected: path traversal detected")]
PathTraversal,
#[error("Command rejected: absolute paths not allowed")]
AbsolutePath,
#[error("Command rejected: write operations not allowed via NL")]
WriteOperation,
#[error("Command rejected: exceeds maximum length")]
CommandTooLong,
}
#[derive(Error, Debug, Clone, PartialEq, Eq)]
pub enum CacheError {
#[error("Cache is disabled")]
Disabled,
#[error("Cache entry has expired")]
Expired,
#[error("Failed to generate cache key: {0}")]
KeyGenerationFailed(String),
}
// Unit tests for error formatting, conversions into `NlError`, and `std::error::Error` impls.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_error_display() {
        let msg = PreprocessError::InputTooLong {
            len: 5000,
            max: 4096,
        }
        .to_string();
        assert!(msg.contains("5000"));
        assert!(msg.contains("4096"));
    }

    #[test]
    fn test_error_conversion() {
        let nl_err: NlError = PreprocessError::EmptyInput.into();
        // Pin the exact wrapped variant, not just the outer one.
        assert!(matches!(
            nl_err,
            NlError::Preprocess(PreprocessError::EmptyInput)
        ));
    }

    // The hand-written `From<ClassifierError>` has two branches; exercise both.
    #[test]
    fn test_classifier_onnx_runtime_missing_is_lifted() {
        let nl_err: NlError = ClassifierError::OnnxRuntimeMissing {
            hint: "install onnxruntime".to_string(),
        }
        .into();
        match nl_err {
            NlError::OnnxRuntimeMissing { hint } => assert_eq!(hint, "install onnxruntime"),
            other => panic!("expected NlError::OnnxRuntimeMissing, got {:?}", other),
        }
    }

    #[test]
    fn test_classifier_other_errors_are_wrapped() {
        let nl_err: NlError = ClassifierError::Timeout { timeout_ms: 250 }.into();
        assert!(matches!(
            nl_err,
            NlError::Classifier(ClassifierError::Timeout { timeout_ms: 250 })
        ));
    }

    #[test]
    fn test_errors_implement_std_error() {
        fn assert_error<T: std::error::Error>() {}
        assert_error::<NlError>();
        assert_error::<PreprocessError>();
        assert_error::<ExtractorError>();
        assert_error::<ClassifierError>();
        assert_error::<AssemblerError>();
        assert_error::<ValidatorError>();
        assert_error::<CacheError>();
    }
}