Skip to main content

ares/types/
mod.rs

1//! Core types used throughout the A.R.E.S server.
2//!
3//! This module contains all the common data structures used for:
4//! - API requests and responses
5//! - Agent configuration and context
6//! - Memory and user preferences
7//! - Tool definitions and calls
8//! - RAG (Retrieval Augmented Generation)
9//! - Authentication
10//! - Error handling
11
12use chrono::{DateTime, Utc};
13use serde::{Deserialize, Serialize};
14use utoipa::ToSchema;
15
16/// Default datetime for serde deserialization
17fn default_datetime() -> DateTime<Utc> {
18    Utc::now()
19}
20
21// ============= API Request/Response Types =============
22
/// Request payload for chat endpoints.
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct ChatRequest {
    /// The user's message to send to the agent.
    pub message: String,
    /// Optional agent type to handle the request. Defaults to router.
    // NOTE(review): the router default is not applied by this type; presumably
    // the handler substitutes it when this is `None` — confirm.
    // Omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub agent_type: Option<AgentType>,
    /// Optional context ID for conversation continuity.
    // Omitted from serialized output when `None`; may be absent on input.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub context_id: Option<String>,
}
35
/// Response from chat endpoints.
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct ChatResponse {
    /// The agent's response text.
    pub response: String,
    /// The name of the agent that handled the request.
    pub agent: String,
    /// Context ID for continuing this conversation.
    pub context_id: String,
    /// Optional sources used to generate the response.
    // There is no `skip_serializing_if` here, so `None` is serialized as an
    // explicit `null` rather than being omitted.
    pub sources: Option<Vec<Source>>,
}
48
/// A source reference used in responses.
#[derive(Debug, Serialize, Deserialize, ToSchema, Clone)]
pub struct Source {
    /// Title of the source document or webpage.
    pub title: String,
    /// URL of the source, if available.
    pub url: Option<String>,
    /// Relevance score (0.0 to 1.0) indicating how relevant this source is.
    // NOTE(review): the 0.0–1.0 range is documented but not validated by this type.
    pub relevance_score: f32,
}
59
/// Request payload for deep research endpoints.
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct ResearchRequest {
    /// The research query or question.
    pub query: String,
    /// Optional maximum depth for recursive research (default: 3).
    // NOTE(review): neither documented default (3 here, 10 below) is applied by
    // this type — there is no `#[serde(default = ...)]`, so a missing value is
    // simply `None`. Presumably the handler applies the defaults; confirm.
    pub depth: Option<u8>,
    /// Optional maximum iterations across all agents (default: 10).
    pub max_iterations: Option<u8>,
}
70
/// Response from deep research endpoints.
// Unlike `ChatResponse::sources`, `sources` here is not optional: it is always
// serialized as an array (possibly empty).
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct ResearchResponse {
    /// The compiled research findings.
    pub findings: String,
    /// Sources discovered during research.
    pub sources: Vec<Source>,
    /// Time taken for the research in milliseconds.
    pub duration_ms: u64,
}
81
82// ============= RAG API Types =============
83
/// Request to ingest a document into the RAG system.
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct RagIngestRequest {
    /// Collection name to ingest into.
    pub collection: String,
    /// The text content to ingest.
    pub content: String,
    /// Optional document title.
    pub title: Option<String>,
    /// Optional source URL or path.
    pub source: Option<String>,
    /// Optional tags for categorization.
    // Deserializes to an empty list when absent.
    #[serde(default)]
    pub tags: Vec<String>,
    /// Chunking strategy to use.
    // `None` when absent (`#[serde(default)]` is redundant for an `Option`).
    // NOTE(review): accepted strategy names are not defined by this type;
    // presumably validated by the ingestion handler — confirm.
    #[serde(default)]
    pub chunking_strategy: Option<String>,
}
102
/// Response from document ingestion.
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct RagIngestResponse {
    /// Number of chunks created.
    pub chunks_created: usize,
    /// Document IDs created.
    // NOTE(review): presumably one ID per chunk (len == chunks_created) — confirm.
    pub document_ids: Vec<String>,
    /// Collection name.
    pub collection: String,
}
113
/// Request to search the RAG system.
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct RagSearchRequest {
    /// Collection to search.
    pub collection: String,
    /// The search query.
    pub query: String,
    /// Maximum results to return (default: 10).
    // Filled from `default_search_limit` when absent.
    #[serde(default = "default_search_limit")]
    pub limit: usize,
    /// Search strategy to use: semantic, bm25, fuzzy, hybrid.
    // Free-form string, `None` when absent. Which strategy is used for `None`,
    // and how unknown names are handled, is decided elsewhere — NOTE(review):
    // confirm in the search handler.
    #[serde(default)]
    pub strategy: Option<String>,
    /// Minimum similarity threshold (0.0 to 1.0).
    // Filled from `default_search_threshold` (0.0) when absent.
    #[serde(default = "default_search_threshold")]
    pub threshold: f32,
    /// Whether to enable reranking.
    // `false` when absent.
    #[serde(default)]
    pub rerank: bool,
    /// Reranker model to use if reranking.
    // NOTE(review): presumably ignored when `rerank` is false — confirm.
    #[serde(default)]
    pub reranker_model: Option<String>,
}
137
/// Serde fallback for `RagSearchRequest::limit` when the client omits it.
fn default_search_limit() -> usize {
    const DEFAULT_LIMIT: usize = 10;
    DEFAULT_LIMIT
}
141
/// Serde fallback for `RagSearchRequest::threshold` when the client omits it:
/// a threshold of `0.0`.
fn default_search_threshold() -> f32 {
    const DEFAULT_THRESHOLD: f32 = 0.0;
    DEFAULT_THRESHOLD
}
145
/// Single search result.
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct RagSearchResult {
    /// Document ID.
    pub id: String,
    /// Matching text content.
    pub content: String,
    /// Relevance score.
    // NOTE(review): the scale presumably depends on the search strategy used
    // (e.g. cosine similarity vs. BM25) — confirm before comparing across strategies.
    pub score: f32,
    /// Document metadata.
    pub metadata: DocumentMetadata,
}
158
/// Response from RAG search.
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct RagSearchResponse {
    /// Search results.
    pub results: Vec<RagSearchResult>,
    /// Total number of results before limit.
    // May therefore exceed `results.len()`.
    pub total: usize,
    /// Search strategy used.
    pub strategy: String,
    /// Whether reranking was applied.
    pub reranked: bool,
    /// Query processing time in milliseconds.
    pub duration_ms: u64,
}
173
/// Request to delete a collection.
// Single-field body; the collection name is taken verbatim (no normalization here).
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct RagDeleteCollectionRequest {
    /// Collection name to delete.
    pub collection: String,
}
180
/// Response from collection deletion.
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct RagDeleteCollectionResponse {
    /// Whether deletion was successful.
    pub success: bool,
    /// Collection that was deleted.
    pub collection: String,
    /// Number of documents deleted.
    // NOTE(review): presumably 0 when `success` is false — confirm in the handler.
    pub documents_deleted: usize,
}
191
192// ============= Workflow Types =============
193
/// Request payload for workflow execution endpoints.
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct WorkflowRequest {
    /// The query to process through the workflow.
    pub query: String,
    /// Additional context data as key-value pairs.
    // Deserializes to an empty map when absent; values are arbitrary JSON.
    #[serde(default)]
    pub context: std::collections::HashMap<String, serde_json::Value>,
}
203
204// ============= Agent Types =============
205
/// Available agent types in the system.
///
/// This enum supports both built-in agent types and custom user-defined agents.
/// The `Custom` variant allows for extensibility without modifying this enum.
// Serde representation: built-in variants (de)serialize as their lowercase
// names ("router", "orchestrator", ..., "hr"). Any other string deserializes
// into the untagged `Custom` variant, and `Custom(name)` serializes as the
// bare `name`.
// NOTE(review): serde's variant matching is case-sensitive, so "Router"
// deserializes to `Custom("Router")`, whereas `AgentType::from_string("Router")`
// yields `Router` — confirm whether callers depend on this difference.
#[derive(Debug, Serialize, Deserialize, ToSchema, Clone, PartialEq, Eq)]
#[serde(rename_all = "lowercase")]
#[non_exhaustive]
pub enum AgentType {
    /// Routes requests to appropriate specialized agents.
    Router,
    /// Orchestrates complex multi-step tasks.
    Orchestrator,
    /// Handles product-related queries.
    Product,
    /// Handles invoice and billing queries.
    Invoice,
    /// Handles sales-related queries.
    Sales,
    /// Handles financial queries and analysis.
    Finance,
    /// Handles HR and employee-related queries.
    #[serde(rename = "hr")]
    HR,
    /// Custom user-defined agent type.
    /// The string contains the agent's unique identifier/name.
    // Must stay the last variant: serde requires untagged variants to follow
    // all tagged ones.
    #[serde(untagged)]
    Custom(String),
}
234
235impl AgentType {
236    /// Returns the agent type name as a string slice.
237    pub fn as_str(&self) -> &str {
238        match self {
239            AgentType::Router => "router",
240            AgentType::Orchestrator => "orchestrator",
241            AgentType::Product => "product",
242            AgentType::Invoice => "invoice",
243            AgentType::Sales => "sales",
244            AgentType::Finance => "finance",
245            AgentType::HR => "hr",
246            AgentType::Custom(name) => name,
247        }
248    }
249
250    /// Creates an AgentType from a string, using built-in types when possible.
251    pub fn from_string(s: &str) -> Self {
252        match s.to_lowercase().as_str() {
253            "router" => AgentType::Router,
254            "orchestrator" => AgentType::Orchestrator,
255            "product" => AgentType::Product,
256            "invoice" => AgentType::Invoice,
257            "sales" => AgentType::Sales,
258            "finance" => AgentType::Finance,
259            "hr" => AgentType::HR,
260            _ => AgentType::Custom(s.to_string()),
261        }
262    }
263
264    /// Returns true if this is a built-in agent type.
265    pub fn is_builtin(&self) -> bool {
266        !matches!(self, AgentType::Custom(_))
267    }
268}
269
270impl std::fmt::Display for AgentType {
271    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
272        write!(f, "{}", self.as_str())
273    }
274}
275
/// Context passed to agents during request processing.
// Only `Debug` and `Clone` are derived: this type is not (de)serialized or
// exposed in the OpenAPI schema directly.
#[derive(Debug, Clone)]
pub struct AgentContext {
    /// Unique identifier for the user making the request.
    pub user_id: String,
    /// Session identifier for conversation tracking.
    pub session_id: String,
    /// Previous messages in the conversation.
    // NOTE(review): ordering (presumably oldest first) is not established by
    // this type — confirm with whoever builds the context.
    pub conversation_history: Vec<Message>,
    /// User's stored memory and preferences.
    // NOTE(review): presumably `None` when no memory has been loaded for this
    // user — confirm.
    pub user_memory: Option<UserMemory>,
}
288
/// A single message in a conversation.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// The role of the message sender.
    pub role: MessageRole,
    /// The message content.
    pub content: String,
    /// When the message was sent.
    // No serde default here, so the timestamp is required when deserializing
    // (unlike `DocumentMetadata::created_at`).
    pub timestamp: DateTime<Utc>,
}
299
/// Role of a message sender in a conversation.
// Serialized as the lowercase strings "system", "user" and "assistant".
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    /// System instructions to the model.
    System,
    /// Message from the user.
    User,
    /// Response from the assistant/agent.
    Assistant,
}
311
312// ============= Memory Types =============
313
/// User memory containing preferences and learned facts.
// Both lists are required when deserializing (no serde defaults); an empty
// memory is represented by empty vectors.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserMemory {
    /// The user's unique identifier.
    pub user_id: String,
    /// List of user preferences.
    pub preferences: Vec<Preference>,
    /// List of facts learned about the user.
    pub facts: Vec<MemoryFact>,
}
324
/// A user preference entry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Preference {
    /// Category of the preference (e.g., "communication", "output").
    pub category: String,
    /// Key identifying the specific preference.
    pub key: String,
    /// The preference value.
    // Stored as a plain string regardless of the preference's logical type.
    pub value: String,
    /// Confidence score (0.0 to 1.0) for this preference.
    // NOTE(review): the range is documented, not enforced by this type.
    pub confidence: f32,
}
337
/// A fact learned about a user.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryFact {
    /// Unique identifier for this fact.
    pub id: String,
    /// The user this fact belongs to.
    // NOTE(review): presumably equal to the owning `UserMemory::user_id` when
    // nested there — confirm; nothing here enforces it.
    pub user_id: String,
    /// Category of the fact (e.g., "personal", "work").
    pub category: String,
    /// Key identifying the specific fact.
    pub fact_key: String,
    /// The fact value.
    pub fact_value: String,
    /// Confidence score (0.0 to 1.0) for this fact.
    // NOTE(review): the range is documented, not enforced by this type.
    pub confidence: f32,
    /// When this fact was first recorded.
    // Both timestamps are required when deserializing (no serde defaults).
    pub created_at: DateTime<Utc>,
    /// When this fact was last updated.
    pub updated_at: DateTime<Utc>,
}
358
359// ============= Tool Types =============
360
/// Definition of a tool that can be called by an LLM.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ToolDefinition {
    /// Unique name of the tool.
    pub name: String,
    /// Human-readable description of what the tool does.
    pub description: String,
    /// JSON Schema defining the tool's parameters.
    // Any JSON value is accepted here; it is not validated as a JSON Schema
    // by this type.
    pub parameters: serde_json::Value,
}
371
/// A request to call a tool.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ToolCall {
    /// Unique identifier for this tool call.
    // Referenced by `ToolResult::tool_call_id` to pair a result with its call.
    pub id: String,
    /// Name of the tool to call.
    pub name: String,
    /// Arguments to pass to the tool.
    // NOTE(review): presumably expected to conform to the matching
    // `ToolDefinition::parameters` schema — not checked here; confirm.
    pub arguments: serde_json::Value,
}
382
383/// Result from executing a tool.
384#[derive(Debug, Serialize, Deserialize)]
385pub struct ToolResult {
386    /// ID of the tool call this result corresponds to.
387    pub tool_call_id: String,
388    /// The result data from the tool execution.
389    pub result: serde_json::Value,
390}
391
392// ============= RAG Types =============
393
/// A document in the RAG knowledge base.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Document {
    /// Unique identifier for the document.
    pub id: String,
    /// The document's text content.
    pub content: String,
    /// Metadata about the document.
    pub metadata: DocumentMetadata,
    /// Optional embedding vector for semantic search.
    // NOTE(review): `None` presumably means "not embedded yet"; the vector
    // length is not fixed by this type and depends on the embedding model — confirm.
    pub embedding: Option<Vec<f32>>,
}
406
/// Metadata associated with a document.
// NOTE(review): `created_at` has two different "defaults". When the field is
// missing during deserialization, `default_datetime` supplies the current
// time. `DocumentMetadata::default()` (derived) instead uses
// `DateTime::<Utc>::default()`, presumably the Unix epoch. Confirm which one
// callers expect and consider aligning them.
#[derive(Debug, Clone, Default, Serialize, Deserialize, ToSchema)]
pub struct DocumentMetadata {
    /// Title of the document.
    // Empty string when absent from the input.
    #[serde(default)]
    pub title: String,
    /// Source of the document (e.g., URL, file path).
    // Empty string when absent from the input.
    #[serde(default)]
    pub source: String,
    /// When the document was created or ingested.
    #[serde(default = "default_datetime")]
    pub created_at: DateTime<Utc>,
    /// Tags for categorization and filtering.
    // Empty list when absent from the input.
    #[serde(default)]
    pub tags: Vec<String>,
}
423
/// Query parameters for semantic search.
// Internal type (no serde derives); the HTTP-facing equivalent is `RagSearchRequest`.
#[derive(Debug, Clone)]
pub struct SearchQuery {
    /// The search query text.
    pub query: String,
    /// Maximum number of results to return.
    pub limit: usize,
    /// Minimum similarity threshold (0.0 to 1.0).
    pub threshold: f32,
    /// Optional filters to apply to results.
    // NOTE(review): how multiple filters combine (AND vs. OR), and whether
    // `Some(vec![])` differs from `None`, is decided by the search
    // implementation — confirm.
    pub filters: Option<Vec<SearchFilter>>,
}
436
/// A filter to apply during search.
// NOTE(review): the match semantics (presumably exact string equality on a
// metadata field) are not visible here — confirm in the search implementation.
#[derive(Debug, Clone)]
pub struct SearchFilter {
    /// Field name to filter on.
    pub field: String,
    /// Value to filter by.
    pub value: String,
}
445
/// A single search result with relevance score.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// The matching document.
    pub document: Document,
    /// Similarity score (0.0 to 1.0).
    // NOTE(review): the range is documented, not enforced by this type.
    pub score: f32,
}
454
455// ============= Authentication Types =============
456
457/// Request payload for user login.
458#[derive(Debug, Serialize, Deserialize, ToSchema)]
459pub struct LoginRequest {
460    /// User's email address.
461    pub email: String,
462    /// User's password.
463    pub password: String,
464}
465
466/// Request payload for user registration.
467#[derive(Debug, Serialize, Deserialize, ToSchema)]
468pub struct RegisterRequest {
469    /// Email address for the new account.
470    pub email: String,
471    /// Password for the new account.
472    pub password: String,
473    /// Display name for the user.
474    pub name: String,
475}
476
477/// Response containing authentication tokens.
478#[derive(Debug, Serialize, Deserialize, ToSchema)]
479pub struct TokenResponse {
480    /// JWT access token for API authentication.
481    pub access_token: String,
482    /// Refresh token for obtaining new access tokens.
483    pub refresh_token: String,
484    /// Time in seconds until the access token expires.
485    pub expires_in: i64,
486}
487
/// JWT claims embedded in access tokens.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Claims {
    /// Subject (user ID).
    pub sub: String,
    /// User's email address.
    pub email: String,
    /// Expiration time (Unix timestamp).
    // NOTE(review): `usize` is only 32 bits on 32-bit targets, which caps
    // timestamps at 4294967295 (year 2106) — confirm target platforms.
    pub exp: usize,
    /// Issued at time (Unix timestamp).
    pub iat: usize,
}
500
501// ============= Error Types =============
502
503/// Error codes for programmatic error handling.
504/// These are stable identifiers that clients can use to handle specific error cases.
505#[derive(Debug, Clone, Copy, Serialize)]
506#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
507pub enum ErrorCode {
508    /// Database operation failed
509    DatabaseError,
510    /// LLM/AI model operation failed
511    LlmError,
512    /// Authentication failed (invalid credentials)
513    AuthenticationFailed,
514    /// Authorization failed (valid credentials but insufficient permissions)
515    AuthorizationFailed,
516    /// Requested resource was not found
517    NotFound,
518    /// Input validation failed
519    InvalidInput,
520    /// Server configuration error
521    ConfigurationError,
522    /// External service (API, webhook, etc.) failed
523    ExternalServiceError,
524    /// Internal server error
525    InternalError,
526}
527
/// Application-wide error type.
// Each variant carries a human-readable message. The `#[error]` attributes
// (thiserror) define `Display` as "<category>: <message>". The `ErrorCode`
// for each variant comes from `AppError::code`, and the HTTP status from the
// `IntoResponse` impl.
#[derive(Debug, thiserror::Error)]
pub enum AppError {
    /// Database operation failed.
    #[error("Database error: {0}")]
    Database(String),

    /// LLM operation failed.
    #[error("LLM error: {0}")]
    LLM(String),

    /// Authentication or authorization failed.
    // NOTE(review): both cases share this variant, which `code()` reports as
    // `ErrorCode::AuthenticationFailed`; authorization failures cannot
    // currently be distinguished.
    #[error("Authentication error: {0}")]
    Auth(String),

    /// Requested resource was not found.
    #[error("Not found: {0}")]
    NotFound(String),

    /// Input validation failed.
    #[error("Invalid input: {0}")]
    InvalidInput(String),

    /// Configuration error.
    #[error("Configuration error: {0}")]
    Configuration(String),

    /// External service call failed.
    #[error("External service error: {0}")]
    External(String),

    /// Internal server error.
    #[error("Internal error: {0}")]
    Internal(String),
}
563
564impl AppError {
565    /// Get the error code for this error type.
566    pub fn code(&self) -> ErrorCode {
567        match self {
568            AppError::Database(_) => ErrorCode::DatabaseError,
569            AppError::LLM(_) => ErrorCode::LlmError,
570            AppError::Auth(_) => ErrorCode::AuthenticationFailed,
571            AppError::NotFound(_) => ErrorCode::NotFound,
572            AppError::InvalidInput(_) => ErrorCode::InvalidInput,
573            AppError::Configuration(_) => ErrorCode::ConfigurationError,
574            AppError::External(_) => ErrorCode::ExternalServiceError,
575            AppError::Internal(_) => ErrorCode::InternalError,
576        }
577    }
578
579    /// Check if this is an internal error that should be logged.
580    fn is_internal(&self) -> bool {
581        matches!(
582            self,
583            AppError::Database(_)
584                | AppError::LLM(_)
585                | AppError::Configuration(_)
586                | AppError::Internal(_)
587        )
588    }
589}
590
591// ============= Error Conversions =============
592
593impl From<std::io::Error> for AppError {
594    fn from(err: std::io::Error) -> Self {
595        AppError::Internal(format!("IO error: {}", err))
596    }
597}
598
599impl From<serde_json::Error> for AppError {
600    fn from(err: serde_json::Error) -> Self {
601        AppError::InvalidInput(format!("JSON error: {}", err))
602    }
603}
604
605impl axum::response::IntoResponse for AppError {
606    fn into_response(self) -> axum::response::Response {
607        // Log internal errors before returning
608        if self.is_internal() {
609            tracing::error!(error = %self, code = ?self.code(), "Internal error occurred");
610        }
611
612        let (status, message) = match &self {
613            AppError::Database(msg) => (axum::http::StatusCode::INTERNAL_SERVER_ERROR, msg.clone()),
614            AppError::LLM(msg) => (axum::http::StatusCode::INTERNAL_SERVER_ERROR, msg.clone()),
615            AppError::Auth(msg) => (axum::http::StatusCode::UNAUTHORIZED, msg.clone()),
616            AppError::NotFound(msg) => (axum::http::StatusCode::NOT_FOUND, msg.clone()),
617            AppError::InvalidInput(msg) => (axum::http::StatusCode::BAD_REQUEST, msg.clone()),
618            AppError::Configuration(msg) => {
619                (axum::http::StatusCode::INTERNAL_SERVER_ERROR, msg.clone())
620            }
621            AppError::External(msg) => (axum::http::StatusCode::BAD_GATEWAY, msg.clone()),
622            AppError::Internal(msg) => (axum::http::StatusCode::INTERNAL_SERVER_ERROR, msg.clone()),
623        };
624
625        let body = serde_json::json!({
626            "error": message,
627            "code": self.code()
628        });
629
630        (status, axum::Json(body)).into_response()
631    }
632}
633
634/// A specialized Result type for A.R.E.S operations.
635pub type Result<T> = std::result::Result<T, AppError>;