use serde::{Deserialize, Serialize};
/// Intent assigned to a request by classification (see [`ClassificationResult`]).
///
/// Each variant has a stable class index ([`Intent::to_index`] /
/// [`Intent::from_index`]) and a `snake_case` name ([`Intent::as_str`]) that is
/// also its serde representation (`rename_all = "snake_case"`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Intent {
    /// Search for symbols by name or pattern (class index 0).
    SymbolQuery,
    /// Search for text patterns in code (class index 1).
    TextSearch,
    /// Find call path between two symbols (class index 2).
    TracePath,
    /// Find all places that call a symbol (class index 3).
    FindCallers,
    /// Find all symbols called by a function (class index 4).
    FindCallees,
    /// Generate a diagram of code relationships (class index 5).
    Visualize,
    /// Check the status of the code index (class index 6).
    IndexStatus,
    /// Intent unclear, needs clarification (class index 7).
    ///
    /// Also the fallback for any out-of-range class index.
    Ambiguous,
}

impl Intent {
    /// Number of intent classes; equals the number of variants.
    pub const NUM_CLASSES: usize = 8;

    /// Every variant, ordered by class index, so `ALL[i].to_index() == i`.
    ///
    /// A compile-time check after this `impl` keeps this table in sync with
    /// [`Intent::to_index`].
    pub const ALL: [Self; Self::NUM_CLASSES] = [
        Self::SymbolQuery,
        Self::TextSearch,
        Self::TracePath,
        Self::FindCallers,
        Self::FindCallees,
        Self::Visualize,
        Self::IndexStatus,
        Self::Ambiguous,
    ];

    /// Maps a class index to its intent.
    ///
    /// Indices outside `0..NUM_CLASSES` do not panic; they fall back to
    /// [`Intent::Ambiguous`].
    #[must_use]
    pub fn from_index(idx: usize) -> Self {
        Self::ALL.get(idx).copied().unwrap_or(Self::Ambiguous)
    }

    /// Returns this intent's class index, in `0..NUM_CLASSES`.
    #[must_use]
    pub const fn to_index(self) -> usize {
        match self {
            Self::SymbolQuery => 0,
            Self::TextSearch => 1,
            Self::TracePath => 2,
            Self::FindCallers => 3,
            Self::FindCallees => 4,
            Self::Visualize => 5,
            Self::IndexStatus => 6,
            Self::Ambiguous => 7,
        }
    }

    /// Returns the `snake_case` name of this intent, e.g. `"symbol_query"`.
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::SymbolQuery => "symbol_query",
            Self::TextSearch => "text_search",
            Self::TracePath => "trace_path",
            Self::FindCallers => "find_callers",
            Self::FindCallees => "find_callees",
            Self::Visualize => "visualize",
            Self::IndexStatus => "index_status",
            Self::Ambiguous => "ambiguous",
        }
    }

    /// Returns a one-sentence, human-readable description of this intent.
    #[must_use]
    pub const fn description(&self) -> &'static str {
        match self {
            Self::SymbolQuery => "Search for symbols by name or pattern",
            Self::TextSearch => "Search for text patterns in code",
            Self::TracePath => "Find call path between two symbols",
            Self::FindCallers => "Find all places that call a symbol",
            Self::FindCallees => "Find all symbols called by a function",
            Self::Visualize => "Generate a diagram of code relationships",
            Self::IndexStatus => "Check the status of the code index",
            Self::Ambiguous => "Intent unclear, needs clarification",
        }
    }
}

// Compile-time guard: `ALL` must list the variants in `to_index` order, which
// makes `from_index(to_index(x)) == x` hold for every variant. A mismatch
// fails the build instead of silently mapping indices to the wrong intent.
const _: () = {
    let mut i = 0;
    while i < Intent::NUM_CLASSES {
        assert!(
            Intent::ALL[i].to_index() == i,
            "Intent::ALL is out of sync with Intent::to_index"
        );
        i += 1;
    }
};

impl std::fmt::Display for Intent {
    /// Writes [`Intent::as_str`], honouring width, fill, alignment and
    /// precision flags (e.g. `format!("{:<14}", intent)` pads the name).
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `Formatter::pad` applies the caller's formatting flags, unlike
        // `write!(f, "{}", ..)`, which discards them.
        f.pad(self.as_str())
    }
}
/// Outcome of validating a candidate command.
///
/// `Valid` accepts the command. Each `Rejected*` variant names the check that
/// failed, and [`ValidationStatus::rejection_reason`] gives a human-readable
/// explanation for it.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ValidationStatus {
    Valid,
    RejectedMetachar,
    RejectedPathTraversal,
    RejectedWriteMode,
    RejectedEnvVar,
    RejectedTooLong,
    RejectedUnknown,
}

impl ValidationStatus {
    /// Returns `true` only for [`ValidationStatus::Valid`].
    #[must_use]
    pub const fn is_valid(&self) -> bool {
        // A status is valid exactly when it carries no rejection reason.
        self.rejection_reason().is_none()
    }

    /// Returns the human-readable rejection explanation, or `None` when valid.
    #[must_use]
    pub const fn rejection_reason(&self) -> Option<&'static str> {
        let reason = match self {
            Self::Valid => return None,
            Self::RejectedMetachar => "Contains shell metacharacters",
            Self::RejectedPathTraversal => "Contains path traversal",
            Self::RejectedWriteMode => "Attempts write operation",
            Self::RejectedEnvVar => "Contains environment variable",
            Self::RejectedTooLong => "Exceeds maximum command length",
            Self::RejectedUnknown => "Doesn't match allowed command patterns",
        };
        Some(reason)
    }
}
/// Predicate category extracted from a request
/// (see `ExtractedEntities::predicate_type`).
///
/// Each variant has a textual `name:` prefix, available from
/// [`PredicateType::as_prefix`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum PredicateType {
    /// Prefix `impl:`.
    Impl,
    /// Prefix `duplicates:`.
    Duplicates,
    /// Prefix `circular:`.
    Circular,
    /// Prefix `unused:`.
    Unused,
}

impl PredicateType {
    /// Returns the predicate prefix including its trailing colon, e.g. `"impl:"`.
    #[must_use]
    pub const fn as_prefix(&self) -> &'static str {
        match self {
            Self::Impl => "impl:",
            Self::Duplicates => "duplicates:",
            Self::Circular => "circular:",
            Self::Unused => "unused:",
        }
    }
}

impl std::fmt::Display for PredicateType {
    /// Writes [`PredicateType::as_prefix`], including the trailing colon, and
    /// honours width, fill, alignment and precision flags.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` respects the caller's formatting flags; `write!(f, "{}", ..)` would drop them.
        f.pad(self.as_prefix())
    }
}
/// Visibility value extracted from a request (see `ExtractedEntities::visibility`).
///
/// The string form (`"public"` / `"private"`) matches the serde representation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum Visibility {
    Public,
    Private,
}

impl Visibility {
    /// Returns `"public"` or `"private"`.
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::Public => "public",
            Self::Private => "private",
        }
    }
}

impl std::fmt::Display for Visibility {
    /// Writes [`Visibility::as_str`], honouring width, fill, alignment and
    /// precision flags.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` respects the caller's formatting flags; `write!(f, "{}", ..)` would drop them.
        f.pad(self.as_str())
    }
}
/// Kind of code symbol.
///
/// [`SymbolKind::as_str`] yields the `snake_case` variant name (e.g.
/// `"type_alias"`), which is also the serde representation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum SymbolKind {
    Function,
    Class,
    Struct,
    Enum,
    Trait,
    Interface,
    Method,
    Module,
    Constant,
    Variable,
    TypeAlias,
}

impl SymbolKind {
    /// Returns the `snake_case` name of this kind, e.g. `"function"` or `"type_alias"`.
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::Function => "function",
            Self::Class => "class",
            Self::Struct => "struct",
            Self::Enum => "enum",
            Self::Trait => "trait",
            Self::Interface => "interface",
            Self::Method => "method",
            Self::Module => "module",
            Self::Constant => "constant",
            Self::Variable => "variable",
            Self::TypeAlias => "type_alias",
        }
    }
}

impl std::fmt::Display for SymbolKind {
    /// Writes [`SymbolKind::as_str`], honouring width, fill, alignment and
    /// precision flags.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` respects the caller's formatting flags; `write!(f, "{}", ..)` would drop them.
        f.pad(self.as_str())
    }
}
/// Output format selector: `mermaid`, `dot`, or `json`
/// (see `ExtractedEntities::format`).
///
/// The string form matches the serde representation.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum OutputFormat {
    Mermaid,
    Dot,
    Json,
}

impl OutputFormat {
    /// Returns `"mermaid"`, `"dot"`, or `"json"`.
    #[must_use]
    pub const fn as_str(&self) -> &'static str {
        match self {
            Self::Mermaid => "mermaid",
            Self::Dot => "dot",
            Self::Json => "json",
        }
    }
}

impl std::fmt::Display for OutputFormat {
    /// Writes [`OutputFormat::as_str`], honouring width, fill, alignment and
    /// precision flags.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `pad` respects the caller's formatting flags; `write!(f, "{}", ..)` would drop them.
        f.pad(self.as_str())
    }
}
/// Entities extracted from a request.
///
/// Every field is empty (`Vec`) or `None` (`Option`) by default, so callers
/// fill in only what was found.
#[derive(Debug, Clone, Default, PartialEq, Serialize, Deserialize)]
pub struct ExtractedEntities {
    /// Extracted symbol names; the first is the primary symbol.
    pub symbols: Vec<String>,
    /// Extracted language names.
    pub languages: Vec<String>,
    /// Extracted paths.
    pub paths: Vec<String>,
    /// Extracted symbol kind, if any.
    pub kind: Option<SymbolKind>,
    /// Extracted limit, if any.
    pub limit: Option<u32>,
    /// Extracted depth, if any.
    pub depth: Option<u32>,
    /// Extracted output format, if any.
    pub format: Option<OutputFormat>,
    /// Start symbol of a trace path (see [`ExtractedEntities::has_trace_path`]).
    pub from_symbol: Option<String>,
    /// End symbol of a trace path (see [`ExtractedEntities::has_trace_path`]).
    pub to_symbol: Option<String>,
    /// Extracted relation name, if any.
    pub relation: Option<String>,
    /// Extracted predicate category, if any.
    pub predicate_type: Option<PredicateType>,
    /// Trait named in an `impl` predicate, if any.
    pub impl_trait: Option<String>,
    /// Argument accompanying a predicate, if any.
    pub predicate_arg: Option<String>,
    /// Extracted visibility filter, if any.
    pub visibility: Option<Visibility>,
    /// Extracted async flag, if any.
    pub is_async: Option<bool>,
    /// Extracted unsafe flag, if any.
    pub is_unsafe: Option<bool>,
}

impl ExtractedEntities {
    /// Creates an empty set of entities; identical to [`Default::default`].
    #[must_use]
    pub fn new() -> Self {
        Default::default()
    }

    /// Returns `true` if at least one symbol was extracted.
    #[must_use]
    pub fn has_symbols(&self) -> bool {
        self.primary_symbol().is_some()
    }

    /// Returns `true` only when both `from_symbol` and `to_symbol` are present.
    #[must_use]
    pub fn has_trace_path(&self) -> bool {
        matches!((&self.from_symbol, &self.to_symbol), (Some(_), Some(_)))
    }

    /// Returns the first extracted symbol, or `None` if there are none.
    #[must_use]
    pub fn primary_symbol(&self) -> Option<&str> {
        match self.symbols.as_slice() {
            [first, ..] => Some(first.as_str()),
            [] => None,
        }
    }

    /// Returns `true` if any predicate-related field is set.
    ///
    /// Consults `predicate_type`, `impl_trait`, `visibility`, `is_async` and
    /// `is_unsafe`. A lone `predicate_arg` does not count.
    #[must_use]
    pub fn has_predicate(&self) -> bool {
        let present = [
            self.predicate_type.is_some(),
            self.impl_trait.is_some(),
            self.visibility.is_some(),
            self.is_async.is_some(),
            self.is_unsafe.is_some(),
        ];
        present.contains(&true)
    }

    /// Returns `true` when the predicate is [`PredicateType::Impl`] or an
    /// `impl_trait` was extracted.
    #[must_use]
    pub fn is_impl_query(&self) -> bool {
        self.impl_trait.is_some() || matches!(self.predicate_type, Some(PredicateType::Impl))
    }
}
/// Response returned for a request.
///
/// Serialized as an internally tagged object: the `type` field holds the
/// snake_case variant name (e.g. `{"type":"execute", ...}`), and the
/// variant's fields sit alongside it.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum TranslationResponse {
/// A single command with its confidence, classified intent, a `cached` flag,
/// and `latency_ms` (milliseconds, per the field name).
/// NOTE(review): presumably the command may be run without confirmation — confirm.
Execute {
command: String,
confidence: f32,
intent: Intent,
cached: bool,
latency_ms: u64,
},
/// A single command with its confidence and a `prompt` for the user.
/// NOTE(review): presumably requires user confirmation before running — confirm.
Confirm {
command: String,
confidence: f32,
prompt: String,
},
/// Several candidate commands (`options`) plus a `prompt` for the user.
Disambiguate {
options: Vec<DisambiguationOption>,
prompt: String,
},
/// No command: a `reason` string plus `suggestions`.
Reject {
reason: String,
suggestions: Vec<String>,
},
}
/// One candidate offered in a `TranslationResponse::Disambiguate` response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DisambiguationOption {
/// Candidate command text.
pub command: String,
/// Intent associated with this candidate.
pub intent: Intent,
/// Human-readable description of the candidate.
pub description: String,
/// Confidence score for this candidate.
pub confidence: f32,
}
/// Result of the preprocessing step: processed text plus bookkeeping flags.
#[derive(Debug, Clone)]
pub struct PreprocessResult {
    /// The processed text.
    pub text: String,
    /// Quoted spans captured during preprocessing.
    /// NOTE(review): presumably the contents of quoted substrings — confirm with the producer.
    pub quoted_spans: Vec<String>,
    /// Whether normalization was applied; [`PreprocessResult::ok`] sets it to `false`.
    pub normalized: bool,
    /// Whether homoglyphs were replaced; [`PreprocessResult::ok`] sets it to `false`.
    pub homoglyphs_replaced: bool,
}

impl PreprocessResult {
    /// Builds a result from `text` and `quoted_spans`, with both
    /// `normalized` and `homoglyphs_replaced` cleared.
    #[must_use]
    pub fn ok(text: String, quoted_spans: Vec<String>) -> Self {
        let (normalized, homoglyphs_replaced) = (false, false);
        Self {
            text,
            quoted_spans,
            normalized,
            homoglyphs_replaced,
        }
    }
}
/// Output of intent classification.
#[derive(Debug, Clone)]
pub struct ClassificationResult {
/// The intent selected by the classifier.
pub intent: Intent,
/// Confidence score for `intent`.
pub confidence: f32,
/// Per-class probabilities.
/// NOTE(review): presumably indexed by `Intent::to_index`, with length
/// `Intent::NUM_CLASSES` — confirm against the producer.
pub all_probabilities: Vec<f32>,
/// Version identifier of the model that produced this result.
pub model_version: String,
}
/// Command text paired with the [`TemplateType`] it is associated with.
#[derive(Debug, Clone)]
pub struct AssembledCommand {
/// The assembled command string.
pub command: String,
/// Template category of `command`.
pub template_type: TemplateType,
}
/// Template category used for an [`AssembledCommand`].
///
/// NOTE(review): appears to mirror the non-`Ambiguous` `Intent` variants
/// (Query↔SymbolQuery, Search↔TextSearch, GraphCallers↔FindCallers,
/// GraphCallees↔FindCallees, ...) — confirm before relying on the mapping.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TemplateType {
Query,
Search,
TracePath,
GraphCallers,
GraphCallees,
Visualize,
IndexStatus,
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_intent_round_trip() {
        for i in 0..Intent::NUM_CLASSES {
            let intent = Intent::from_index(i);
            assert_eq!(intent.to_index(), i);
        }
    }

    /// Indices past the last class must fall back to `Ambiguous` rather than panic.
    #[test]
    fn test_intent_from_index_out_of_range() {
        assert_eq!(Intent::from_index(Intent::NUM_CLASSES), Intent::Ambiguous);
        assert_eq!(Intent::from_index(usize::MAX), Intent::Ambiguous);
    }

    /// `as_str` and serde's `rename_all = "snake_case"` must produce the same names.
    #[test]
    fn test_intent_as_str_matches_serde() {
        for i in 0..Intent::NUM_CLASSES {
            let intent = Intent::from_index(i);
            let json = serde_json::to_string(&intent).unwrap();
            assert_eq!(json, format!("\"{}\"", intent.as_str()));
            let parsed: Intent = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, intent);
        }
    }

    #[test]
    fn test_intent_display() {
        assert_eq!(Intent::SymbolQuery.to_string(), "symbol_query");
        assert_eq!(Intent::FindCallers.to_string(), "find_callers");
    }

    #[test]
    fn test_validation_status_is_valid() {
        assert!(ValidationStatus::Valid.is_valid());
        assert!(!ValidationStatus::RejectedMetachar.is_valid());
    }

    /// Exactly the non-`Valid` statuses carry a rejection reason.
    #[test]
    fn test_validation_status_rejection_reason() {
        let statuses = [
            ValidationStatus::Valid,
            ValidationStatus::RejectedMetachar,
            ValidationStatus::RejectedPathTraversal,
            ValidationStatus::RejectedWriteMode,
            ValidationStatus::RejectedEnvVar,
            ValidationStatus::RejectedTooLong,
            ValidationStatus::RejectedUnknown,
        ];
        for status in &statuses {
            assert_eq!(status.is_valid(), status.rejection_reason().is_none());
        }
    }

    /// `SymbolKind::as_str` and the serde representation must agree.
    #[test]
    fn test_symbol_kind_as_str_matches_serde() {
        let kinds = [
            SymbolKind::Function,
            SymbolKind::Class,
            SymbolKind::Struct,
            SymbolKind::Enum,
            SymbolKind::Trait,
            SymbolKind::Interface,
            SymbolKind::Method,
            SymbolKind::Module,
            SymbolKind::Constant,
            SymbolKind::Variable,
            SymbolKind::TypeAlias,
        ];
        for kind in &kinds {
            let json = serde_json::to_string(kind).unwrap();
            assert_eq!(json, format!("\"{}\"", kind.as_str()));
        }
    }

    #[test]
    fn test_extracted_entities_default() {
        let entities = ExtractedEntities::new();
        assert!(!entities.has_symbols());
        assert!(!entities.has_trace_path());
        assert!(!entities.has_predicate());
        assert!(!entities.is_impl_query());
        assert!(entities.primary_symbol().is_none());
    }

    #[test]
    fn test_extracted_entities_with_symbols() {
        let mut entities = ExtractedEntities::new();
        entities.symbols.push("foo".to_string());
        entities.symbols.push("bar".to_string());
        assert!(entities.has_symbols());
        assert_eq!(entities.primary_symbol(), Some("foo"));
    }

    /// A trace path needs both endpoints.
    #[test]
    fn test_extracted_entities_trace_path() {
        let mut entities = ExtractedEntities::new();
        entities.from_symbol = Some("a".to_string());
        assert!(!entities.has_trace_path());
        entities.to_symbol = Some("b".to_string());
        assert!(entities.has_trace_path());
    }

    #[test]
    fn test_extracted_entities_predicates() {
        let mut impl_pred = ExtractedEntities::new();
        impl_pred.predicate_type = Some(PredicateType::Impl);
        assert!(impl_pred.has_predicate());
        assert!(impl_pred.is_impl_query());

        let mut impl_trait = ExtractedEntities::new();
        impl_trait.impl_trait = Some("Display".to_string());
        assert!(impl_trait.has_predicate());
        assert!(impl_trait.is_impl_query());

        let mut unused = ExtractedEntities::new();
        unused.predicate_type = Some(PredicateType::Unused);
        assert!(unused.has_predicate());
        assert!(!unused.is_impl_query());
    }

    #[test]
    fn test_translation_response_serde() {
        let response = TranslationResponse::Execute {
            command: "sqry query \"test\"".to_string(),
            confidence: 0.95,
            intent: Intent::SymbolQuery,
            cached: false,
            latency_ms: 42,
        };
        let json = serde_json::to_string(&response).unwrap();
        assert!(json.contains("\"type\":\"execute\""));
        assert!(json.contains("symbol_query"));
        let parsed: TranslationResponse = serde_json::from_str(&json).unwrap();
        match parsed {
            TranslationResponse::Execute { confidence, .. } => {
                assert!((confidence - 0.95).abs() < f32::EPSILON);
            }
            other => panic!("Wrong variant: {:?}", other),
        }
    }
}