// forgeai_core/lib.rs

//! Core domain types and adapter traits for forgeai-rs.
2
3use async_trait::async_trait;
4use futures_core::Stream;
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7use std::pin::Pin;
8use url::Url;
9
/// Boxed error trait object that is `Send + Sync + 'static`, so it can be
/// moved across threads and stored without borrowing from its source.
pub type BoxError = Box<dyn std::error::Error + Send + Sync + 'static>;
11pub type StreamResult<T> = Pin<Box<dyn Stream<Item = Result<T, ForgeError>> + Send>>;
12
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct ChatRequest {
15    pub model: String,
16    pub messages: Vec<Message>,
17    pub temperature: Option<f32>,
18    pub max_tokens: Option<u32>,
19    pub tools: Vec<ToolDefinition>,
20    pub metadata: Value,
21}
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
24pub struct Message {
25    pub role: Role,
26    pub content: String,
27}
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
30#[serde(rename_all = "snake_case")]
31pub enum Role {
32    System,
33    User,
34    Assistant,
35    Tool,
36}
37
38#[derive(Debug, Clone, Serialize, Deserialize)]
39pub struct ToolDefinition {
40    pub name: String,
41    pub description: Option<String>,
42    pub input_schema: Value,
43}
44
45#[derive(Debug, Clone, Serialize, Deserialize)]
46pub struct ChatResponse {
47    pub id: String,
48    pub model: String,
49    pub output_text: String,
50    pub tool_calls: Vec<ToolCall>,
51    pub usage: Option<Usage>,
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize)]
55pub struct ToolCall {
56    pub id: String,
57    pub name: String,
58    pub arguments: Value,
59}
60
61#[derive(Debug, Clone, Serialize, Deserialize)]
62pub struct Usage {
63    pub input_tokens: u32,
64    pub output_tokens: u32,
65    pub total_tokens: u32,
66}
67
68#[derive(Debug, Clone, Serialize, Deserialize)]
69#[serde(tag = "type", rename_all = "snake_case")]
70pub enum StreamEvent {
71    TextDelta { delta: String },
72    ToolCallDelta { call_id: String, delta: Value },
73    Usage { usage: Usage },
74    Done,
75}
76
77#[derive(Debug, Clone, Serialize, Deserialize)]
78pub struct CapabilityMatrix {
79    pub streaming: bool,
80    pub tools: bool,
81    pub structured_output: bool,
82    pub multimodal_input: bool,
83    pub citations: bool,
84}
85
86#[derive(Debug, Clone, Serialize, Deserialize)]
87pub struct AdapterInfo {
88    pub name: String,
89    pub base_url: Option<Url>,
90    pub capabilities: CapabilityMatrix,
91}
92
93#[derive(Debug, thiserror::Error)]
94pub enum ForgeError {
95    #[error("validation error: {0}")]
96    Validation(String),
97    #[error("authentication error")]
98    Authentication,
99    #[error("rate limited")]
100    RateLimited,
101    #[error("provider error: {0}")]
102    Provider(String),
103    #[error("transport error: {0}")]
104    Transport(String),
105    #[error("internal error: {0}")]
106    Internal(String),
107}
108
109#[async_trait]
110pub trait ChatAdapter: Send + Sync {
111    fn info(&self) -> AdapterInfo;
112
113    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, ForgeError>;
114
115    async fn chat_stream(
116        &self,
117        request: ChatRequest,
118    ) -> Result<StreamResult<StreamEvent>, ForgeError>;
119}
120
121pub fn validate_request(request: &ChatRequest) -> Result<(), ForgeError> {
122    if request.model.trim().is_empty() {
123        return Err(ForgeError::Validation("model cannot be empty".to_string()));
124    }
125    if request.messages.is_empty() {
126        return Err(ForgeError::Validation(
127            "messages cannot be empty".to_string(),
128        ));
129    }
130    Ok(())
131}