Skip to main content

ares/types/
mod.rs

//! Core types used throughout the A.R.E.S server.
//!
//! This module contains all the common data structures used for:
//! - API requests and responses
//! - Agent configuration and context
//! - Memory and user preferences
//! - Tool definitions and calls
//! - RAG (Retrieval-Augmented Generation)
//! - Authentication
//! - Error handling
11
12use chrono::{DateTime, Utc};
13use serde::{Deserialize, Serialize};
14use utoipa::ToSchema;
15
/// Default datetime for serde deserialization.
///
/// Returns the current UTC time, so a field using this default (see
/// [`DocumentMetadata::created_at`]) is stamped with the moment the value was
/// deserialized when the field is absent from the input.
fn default_datetime() -> DateTime<Utc> {
    Utc::now()
}

// ============= API Request/Response Types =============

/// Request payload for chat endpoints.
///
/// Optional fields are omitted from the serialized form when they are `None`.
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct ChatRequest {
    /// The user's message to send to the agent.
    pub message: String,
    /// Optional agent type to handle the request. Defaults to router.
    ///
    /// This type does not apply that default itself: a missing field deserializes
    /// to `None`, and the fallback is chosen by whoever handles the request.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub agent_type: Option<AgentType>,
    /// Optional context ID for conversation continuity.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub context_id: Option<String>,
    /// Optional Eruka workspace_id for per-user context isolation.
    /// When set, the Eruka context middleware queries this workspace instead of the default.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub workspace_id: Option<String>,
}

/// Response from chat endpoints.
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct ChatResponse {
    /// The agent's response text.
    pub response: String,
    /// The name of the agent that handled the request.
    pub agent: String,
    /// Context ID for continuing this conversation.
    pub context_id: String,
    /// Optional sources used to generate the response.
    ///
    /// Unlike the optional fields of [`ChatRequest`], this has no
    /// `skip_serializing_if`, so `None` is serialized as `null`.
    pub sources: Option<Vec<Source>>,
}

/// A source reference used in responses.
#[derive(Debug, Serialize, Deserialize, ToSchema, Clone)]
pub struct Source {
    /// Title of the source document or webpage.
    pub title: String,
    /// URL of the source, if available.
    pub url: Option<String>,
    /// Relevance score (0.0 to 1.0) indicating how relevant this source is.
    pub relevance_score: f32,
}

/// Request payload for deep research endpoints.
///
/// The documented defaults are not applied here: absent fields deserialize to
/// `None`; presumably the research handler substitutes the defaults.
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct ResearchRequest {
    /// The research query or question.
    pub query: String,
    /// Optional maximum depth for recursive research (default: 3).
    pub depth: Option<u8>,
    /// Optional maximum iterations across all agents (default: 10).
    pub max_iterations: Option<u8>,
}

/// Response from deep research endpoints.
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct ResearchResponse {
    /// The compiled research findings.
    pub findings: String,
    /// Sources discovered during research.
    pub sources: Vec<Source>,
    /// Time taken for the research in milliseconds.
    pub duration_ms: u64,
}

// ============= RAG API Types =============

/// Request to ingest a document into the RAG system.
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct RagIngestRequest {
    /// Collection name to ingest into.
    pub collection: String,
    /// The text content to ingest.
    pub content: String,
    /// Optional document title.
    pub title: Option<String>,
    /// Optional source URL or path.
    pub source: Option<String>,
    /// Optional tags for categorization; an absent field deserializes to an empty list.
    #[serde(default)]
    pub tags: Vec<String>,
    /// Chunking strategy to use; `None` when the field is absent.
    // NOTE(review): `#[serde(default)]` is redundant on `Option` fields — serde's
    // derive already treats a missing `Option` field as `None`.
    #[serde(default)]
    pub chunking_strategy: Option<String>,
}

/// Response from document ingestion.
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct RagIngestResponse {
    /// Number of chunks created.
    pub chunks_created: usize,
    /// Document IDs created.
    pub document_ids: Vec<String>,
    /// Collection name.
    pub collection: String,
}

/// Request to search the RAG system.
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct RagSearchRequest {
    /// Collection to search.
    pub collection: String,
    /// The search query.
    pub query: String,
    /// Maximum results to return (default: 10, via `default_search_limit`).
    #[serde(default = "default_search_limit")]
    pub limit: usize,
    /// Search strategy to use: semantic, bm25, fuzzy, hybrid.
    /// `None` when the field is absent.
    #[serde(default)]
    pub strategy: Option<String>,
    /// Minimum similarity threshold (0.0 to 1.0); defaults to 0.0 via
    /// `default_search_threshold`.
    #[serde(default = "default_search_threshold")]
    pub threshold: f32,
    /// Whether to enable reranking (defaults to `false` when absent).
    #[serde(default)]
    pub rerank: bool,
    /// Reranker model to use if reranking.
    #[serde(default)]
    pub reranker_model: Option<String>,
}

/// Serde default for [`RagSearchRequest::limit`]: at most 10 results.
fn default_search_limit() -> usize {
    10
}

/// Serde default for [`RagSearchRequest::threshold`]: a minimum similarity of 0.0.
fn default_search_threshold() -> f32 {
    0.0
}

/// Single search result.
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct RagSearchResult {
    /// Document ID.
    pub id: String,
    /// Matching text content.
    pub content: String,
    /// Relevance score.
    pub score: f32,
    /// Document metadata.
    pub metadata: DocumentMetadata,
}

/// Response from RAG search.
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct RagSearchResponse {
    /// Search results.
    pub results: Vec<RagSearchResult>,
    /// Total number of results before limit.
    pub total: usize,
    /// Search strategy used.
    pub strategy: String,
    /// Whether reranking was applied.
    pub reranked: bool,
    /// Query processing time in milliseconds.
    pub duration_ms: u64,
}

/// Request to delete a collection.
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct RagDeleteCollectionRequest {
    /// Collection name to delete.
    pub collection: String,
}

/// Response from collection deletion.
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct RagDeleteCollectionResponse {
    /// Whether deletion was successful.
    pub success: bool,
    /// Collection that was deleted.
    pub collection: String,
    /// Number of documents deleted.
    pub documents_deleted: usize,
}

// ============= Workflow Types =============

/// Request payload for workflow execution endpoints.
#[derive(Debug, Serialize, Deserialize, ToSchema)]
pub struct WorkflowRequest {
    /// The query to process through the workflow.
    pub query: String,
    /// Additional context data as key-value pairs.
    ///
    /// Deserializes to an empty map when the field is absent.
    #[serde(default)]
    pub context: std::collections::HashMap<String, serde_json::Value>,
}

208// ============= Agent Types =============
209
210/// Available agent types in the system.
211///
212/// This enum supports both built-in agent types and custom user-defined agents.
213/// The `Custom` variant allows for extensibility without modifying this enum.
214#[derive(Debug, Serialize, Deserialize, ToSchema, Clone, PartialEq, Eq)]
215#[serde(rename_all = "lowercase")]
216#[non_exhaustive]
217pub enum AgentType {
218    /// Routes requests to appropriate specialized agents.
219    Router,
220    /// Orchestrates complex multi-step tasks.
221    Orchestrator,
222    /// Handles product-related queries.
223    Product,
224    /// Handles invoice and billing queries.
225    Invoice,
226    /// Handles sales-related queries.
227    Sales,
228    /// Handles financial queries and analysis.
229    Finance,
230    /// Handles HR and employee-related queries.
231    #[serde(rename = "hr")]
232    HR,
233    /// Custom user-defined agent type.
234    /// The string contains the agent's unique identifier/name.
235    #[serde(untagged)]
236    Custom(String),
237}
238
239impl AgentType {
240    /// Returns the agent type name as a string slice.
241    pub fn as_str(&self) -> &str {
242        match self {
243            AgentType::Router => "router",
244            AgentType::Orchestrator => "orchestrator",
245            AgentType::Product => "product",
246            AgentType::Invoice => "invoice",
247            AgentType::Sales => "sales",
248            AgentType::Finance => "finance",
249            AgentType::HR => "hr",
250            AgentType::Custom(name) => name,
251        }
252    }
253
254    /// Creates an AgentType from a string, using built-in types when possible.
255    pub fn from_string(s: &str) -> Self {
256        match s.to_lowercase().as_str() {
257            "router" => AgentType::Router,
258            "orchestrator" => AgentType::Orchestrator,
259            "product" => AgentType::Product,
260            "invoice" => AgentType::Invoice,
261            "sales" => AgentType::Sales,
262            "finance" => AgentType::Finance,
263            "hr" => AgentType::HR,
264            _ => AgentType::Custom(s.to_string()),
265        }
266    }
267
268    /// Returns true if this is a built-in agent type.
269    pub fn is_builtin(&self) -> bool {
270        !matches!(self, AgentType::Custom(_))
271    }
272}
273
274impl std::fmt::Display for AgentType {
275    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
276        write!(f, "{}", self.as_str())
277    }
278}
279
/// Context passed to agents during request processing.
#[derive(Debug, Clone)]
pub struct AgentContext {
    /// Unique identifier for the user making the request.
    pub user_id: String,
    /// Session identifier for conversation tracking.
    pub session_id: String,
    /// Previous messages in the conversation.
    pub conversation_history: Vec<Message>,
    /// User's stored memory and preferences, if any were loaded.
    pub user_memory: Option<UserMemory>,
}

/// A single message in a conversation.
///
/// All fields are required when deserializing (no serde defaults here).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Message {
    /// The role of the message sender.
    pub role: MessageRole,
    /// The message content.
    pub content: String,
    /// When the message was sent.
    pub timestamp: DateTime<Utc>,
}

/// Role of a message sender in a conversation.
///
/// Serialized in lowercase: `"system"`, `"user"`, `"assistant"`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum MessageRole {
    /// System instructions to the model.
    System,
    /// Message from the user.
    User,
    /// Response from the assistant/agent.
    Assistant,
}

// ============= Memory Types =============

/// User memory containing preferences and learned facts.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserMemory {
    /// The user's unique identifier.
    pub user_id: String,
    /// List of user preferences.
    pub preferences: Vec<Preference>,
    /// List of facts learned about the user.
    pub facts: Vec<MemoryFact>,
}

/// A user preference entry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Preference {
    /// Category of the preference (e.g., "communication", "output").
    pub category: String,
    /// Key identifying the specific preference.
    pub key: String,
    /// The preference value.
    pub value: String,
    /// Confidence score (0.0 to 1.0) for this preference.
    /// The range is not enforced by this type.
    pub confidence: f32,
}

/// A fact learned about a user.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryFact {
    /// Unique identifier for this fact.
    pub id: String,
    /// The user this fact belongs to.
    pub user_id: String,
    /// Category of the fact (e.g., "personal", "work").
    pub category: String,
    /// Key identifying the specific fact.
    pub fact_key: String,
    /// The fact value.
    pub fact_value: String,
    /// Confidence score (0.0 to 1.0) for this fact.
    /// The range is not enforced by this type.
    pub confidence: f32,
    /// When this fact was first recorded.
    pub created_at: DateTime<Utc>,
    /// When this fact was last updated.
    pub updated_at: DateTime<Utc>,
}

// ============= Tool Types =============

/// Definition of a tool that can be called by an LLM.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ToolDefinition {
    /// Unique name of the tool.
    pub name: String,
    /// Human-readable description of what the tool does.
    pub description: String,
    /// JSON Schema defining the tool's parameters, kept as untyped JSON.
    pub parameters: serde_json::Value,
}

/// A request to call a tool.
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct ToolCall {
    /// Unique identifier for this tool call.
    pub id: String,
    /// Name of the tool to call; presumably matches a [`ToolDefinition::name`].
    pub name: String,
    /// Arguments to pass to the tool, kept as untyped JSON.
    pub arguments: serde_json::Value,
}

/// Result from executing a tool.
#[derive(Debug, Serialize, Deserialize)]
pub struct ToolResult {
    /// ID of the tool call this result corresponds to (a [`ToolCall::id`]).
    pub tool_call_id: String,
    /// The result data from the tool execution.
    pub result: serde_json::Value,
}

// ============= RAG Types =============

/// A document in the RAG knowledge base.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Document {
    /// Unique identifier for the document.
    pub id: String,
    /// The document's text content.
    pub content: String,
    /// Metadata about the document.
    pub metadata: DocumentMetadata,
    /// Optional embedding vector for semantic search.
    pub embedding: Option<Vec<f32>>,
}

/// Metadata associated with a document.
///
/// Every field has a serde default, so `{}` deserializes successfully.
///
/// NOTE(review): the derived `Default` impl does not use `default_datetime`.
/// `DocumentMetadata::default().created_at` is chrono's `DateTime<Utc>` default
/// (the Unix epoch, per chrono docs), whereas deserializing without the field
/// yields the current time. Confirm the difference is intended.
#[derive(Debug, Clone, Default, Serialize, Deserialize, ToSchema)]
pub struct DocumentMetadata {
    /// Title of the document (empty string when absent).
    #[serde(default)]
    pub title: String,
    /// Source of the document (e.g., URL, file path); empty string when absent.
    #[serde(default)]
    pub source: String,
    /// When the document was created or ingested; set to "now" at
    /// deserialization time when absent.
    #[serde(default = "default_datetime")]
    pub created_at: DateTime<Utc>,
    /// Tags for categorization and filtering (empty when absent).
    #[serde(default)]
    pub tags: Vec<String>,
}

/// Query parameters for semantic search.
#[derive(Debug, Clone)]
pub struct SearchQuery {
    /// The search query text.
    pub query: String,
    /// Maximum number of results to return.
    pub limit: usize,
    /// Minimum similarity threshold (0.0 to 1.0).
    pub threshold: f32,
    /// Optional filters to apply to results.
    pub filters: Option<Vec<SearchFilter>>,
}

/// A filter to apply during search.
///
/// Both field and value are plain strings; how they are matched is up to the
/// search backend.
#[derive(Debug, Clone)]
pub struct SearchFilter {
    /// Field name to filter on.
    pub field: String,
    /// Value to filter by.
    pub value: String,
}

/// A single search result with relevance score.
#[derive(Debug, Clone)]
pub struct SearchResult {
    /// The matching document.
    pub document: Document,
    /// Similarity score (0.0 to 1.0).
    pub score: f32,
}

459// ============= Authentication Types =============
460
461/// Request payload for user login.
462#[derive(Debug, Serialize, Deserialize, ToSchema)]
463pub struct LoginRequest {
464    /// User's email address.
465    pub email: String,
466    /// User's password.
467    pub password: String,
468}
469
470/// Request payload for user registration.
471#[derive(Debug, Serialize, Deserialize, ToSchema)]
472pub struct RegisterRequest {
473    /// Email address for the new account.
474    pub email: String,
475    /// Password for the new account.
476    pub password: String,
477    /// Display name for the user.
478    pub name: String,
479}
480
481/// Response containing authentication tokens.
482#[derive(Debug, Serialize, Deserialize, ToSchema)]
483pub struct TokenResponse {
484    /// JWT access token for API authentication.
485    pub access_token: String,
486    /// Refresh token for obtaining new access tokens.
487    pub refresh_token: String,
488    /// Time in seconds until the access token expires.
489    pub expires_in: i64,
490}
491
492/// JWT claims embedded in access tokens.
493#[derive(Debug, Serialize, Deserialize, Clone)]
494pub struct Claims {
495    /// Subject (user ID).
496    pub sub: String,
497    /// User's email address.
498    pub email: String,
499    /// Expiration time (Unix timestamp).
500    pub exp: usize,
501    /// Issued at time (Unix timestamp).
502    pub iat: usize,
503}
504
505// ============= Error Types =============
506
507/// Error codes for programmatic error handling.
508/// These are stable identifiers that clients can use to handle specific error cases.
509#[derive(Debug, Clone, Copy, Serialize)]
510#[serde(rename_all = "SCREAMING_SNAKE_CASE")]
511pub enum ErrorCode {
512    /// Database operation failed
513    DatabaseError,
514    /// LLM/AI model operation failed
515    LlmError,
516    /// Authentication failed (invalid credentials)
517    AuthenticationFailed,
518    /// Authorization failed (valid credentials but insufficient permissions)
519    AuthorizationFailed,
520    /// Requested resource was not found
521    NotFound,
522    /// Input validation failed
523    InvalidInput,
524    /// Server configuration error
525    ConfigurationError,
526    /// External service (API, webhook, etc.) failed
527    ExternalServiceError,
528    /// Internal server error
529    InternalError,
530}
531
532/// Application-wide error type.
533#[derive(Debug, thiserror::Error)]
534pub enum AppError {
535    /// Database operation failed.
536    #[error("Database error: {0}")]
537    Database(String),
538
539    /// LLM operation failed.
540    #[error("LLM error: {0}")]
541    LLM(String),
542
543    /// Authentication or authorization failed.
544    #[error("Authentication error: {0}")]
545    Auth(String),
546
547    /// Requested resource was not found.
548    #[error("Not found: {0}")]
549    NotFound(String),
550
551    /// Input validation failed.
552    #[error("Invalid input: {0}")]
553    InvalidInput(String),
554
555    /// Configuration error.
556    #[error("Configuration error: {0}")]
557    Configuration(String),
558
559    /// External service call failed.
560    #[error("External service error: {0}")]
561    External(String),
562
563    /// Internal server error.
564    #[error("Internal error: {0}")]
565    Internal(String),
566
567    /// Service temporarily unavailable (emergency stop, maintenance).
568    #[error("Service unavailable: {0}")]
569    Unavailable(String),
570
571    /// Rate limit / quota exceeded.
572    #[error("Rate limited: {0}")]
573    RateLimited(String),
574}
575
576impl AppError {
577    /// Get the error code for this error type.
578    pub fn code(&self) -> ErrorCode {
579        match self {
580            AppError::Database(_) => ErrorCode::DatabaseError,
581            AppError::LLM(_) => ErrorCode::LlmError,
582            AppError::Auth(_) => ErrorCode::AuthenticationFailed,
583            AppError::NotFound(_) => ErrorCode::NotFound,
584            AppError::InvalidInput(_) => ErrorCode::InvalidInput,
585            AppError::Configuration(_) => ErrorCode::ConfigurationError,
586            AppError::External(_) => ErrorCode::ExternalServiceError,
587            AppError::Internal(_) => ErrorCode::InternalError,
588            AppError::Unavailable(_) => ErrorCode::InternalError,
589            AppError::RateLimited(_) => ErrorCode::InternalError,
590        }
591    }
592
593    /// Check if this is an internal error that should be logged.
594    fn is_internal(&self) -> bool {
595        matches!(
596            self,
597            AppError::Database(_)
598                | AppError::LLM(_)
599                | AppError::Configuration(_)
600                | AppError::Internal(_)
601        )
602    }
603}
604
605// ============= Error Conversions =============
606
607impl From<std::io::Error> for AppError {
608    fn from(err: std::io::Error) -> Self {
609        AppError::Internal(format!("IO error: {}", err))
610    }
611}
612
613impl From<serde_json::Error> for AppError {
614    fn from(err: serde_json::Error) -> Self {
615        AppError::InvalidInput(format!("JSON error: {}", err))
616    }
617}
618
619impl axum::response::IntoResponse for AppError {
620    fn into_response(self) -> axum::response::Response {
621        // Log internal errors before returning
622        if self.is_internal() {
623            tracing::error!(error = %self, code = ?self.code(), "Internal error occurred");
624        }
625
626        let (status, message) = match &self {
627            AppError::Database(msg) => (axum::http::StatusCode::INTERNAL_SERVER_ERROR, msg.clone()),
628            AppError::LLM(msg) => (axum::http::StatusCode::INTERNAL_SERVER_ERROR, msg.clone()),
629            AppError::Auth(msg) => (axum::http::StatusCode::UNAUTHORIZED, msg.clone()),
630            AppError::NotFound(msg) => (axum::http::StatusCode::NOT_FOUND, msg.clone()),
631            AppError::InvalidInput(msg) => (axum::http::StatusCode::BAD_REQUEST, msg.clone()),
632            AppError::Configuration(msg) => {
633                (axum::http::StatusCode::INTERNAL_SERVER_ERROR, msg.clone())
634            }
635            AppError::External(msg) => (axum::http::StatusCode::BAD_GATEWAY, msg.clone()),
636            AppError::Internal(msg) => (axum::http::StatusCode::INTERNAL_SERVER_ERROR, msg.clone()),
637            AppError::Unavailable(msg) => (axum::http::StatusCode::SERVICE_UNAVAILABLE, msg.clone()),
638            AppError::RateLimited(msg) => (axum::http::StatusCode::TOO_MANY_REQUESTS, msg.clone()),
639        };
640
641        let body = serde_json::json!({
642            "error": message,
643            "code": self.code()
644        });
645
646        (status, axum::Json(body)).into_response()
647    }
648}
649
/// A specialized Result type for A.R.E.S operations.
///
/// Shorthand for `std::result::Result<T, AppError>`.
pub type Result<T> = std::result::Result<T, AppError>;