Skip to main content

openclaw_providers/
traits.rs

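//! Provider abstraction: the [`Provider`] trait plus the request, response,
//! streaming, and error types that its implementations exchange.
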
3use async_trait::async_trait;
4use serde::{Deserialize, Serialize};
5use thiserror::Error;
6
7use openclaw_core::types::TokenUsage;
8use std::pin::Pin;
9
10/// Provider errors.
11#[derive(Error, Debug)]
12pub enum ProviderError {
13    /// API error.
14    #[error("API error: {status} - {message}")]
15    Api {
16        /// HTTP status code.
17        status: u16,
18        /// Error message.
19        message: String,
20    },
21
22    /// Network error.
23    #[error("Network error: {0}")]
24    Network(#[from] reqwest::Error),
25
26    /// Serialization error.
27    #[error("Serialization error: {0}")]
28    Serialization(#[from] serde_json::Error),
29
30    /// Rate limited.
31    #[error("Rate limited, retry after {retry_after_secs} seconds")]
32    RateLimited {
33        /// Seconds to wait before retry.
34        retry_after_secs: u64,
35    },
36
37    /// Invalid configuration.
38    #[error("Invalid configuration: {0}")]
39    Config(String),
40}
41
42/// Completion request.
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct CompletionRequest {
45    /// Model to use.
46    pub model: String,
47
48    /// Messages in conversation.
49    pub messages: Vec<Message>,
50
51    /// System prompt.
52    pub system: Option<String>,
53
54    /// Maximum tokens to generate.
55    pub max_tokens: u32,
56
57    /// Temperature for sampling.
58    pub temperature: f32,
59
60    /// Stop sequences.
61    pub stop: Option<Vec<String>>,
62
63    /// Tools available.
64    pub tools: Option<Vec<Tool>>,
65}
66
67/// A message in the conversation.
68#[derive(Debug, Clone, Serialize, Deserialize)]
69pub struct Message {
70    /// Message role.
71    pub role: Role,
72
73    /// Message content.
74    pub content: MessageContent,
75}
76
77/// Message role.
78#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
79#[serde(rename_all = "lowercase")]
80pub enum Role {
81    /// User message.
82    User,
83    /// Assistant message.
84    Assistant,
85    /// System message.
86    System,
87    /// Tool result.
88    Tool,
89}
90
91/// Message content.
92#[derive(Debug, Clone, Serialize, Deserialize)]
93#[serde(untagged)]
94pub enum MessageContent {
95    /// Simple text content.
96    Text(String),
97    /// Structured content blocks.
98    Blocks(Vec<ContentBlock>),
99}
100
101/// Content block.
102#[derive(Debug, Clone, Serialize, Deserialize)]
103#[serde(tag = "type", rename_all = "snake_case")]
104pub enum ContentBlock {
105    /// Text block.
106    Text {
107        /// Text content.
108        text: String,
109    },
110    /// Image block.
111    Image {
112        /// Image source.
113        source: ImageSource,
114    },
115    /// Tool use block.
116    ToolUse {
117        /// Tool ID.
118        id: String,
119        /// Tool name.
120        name: String,
121        /// Tool input.
122        input: serde_json::Value,
123    },
124    /// Tool result block.
125    ToolResult {
126        /// Tool use ID.
127        tool_use_id: String,
128        /// Tool result content.
129        content: String,
130        /// Whether tool errored.
131        is_error: Option<bool>,
132    },
133}
134
135/// Image source.
136#[derive(Debug, Clone, Serialize, Deserialize)]
137pub struct ImageSource {
138    /// Source type.
139    #[serde(rename = "type")]
140    pub source_type: String,
141    /// Media type.
142    pub media_type: String,
143    /// Base64 data.
144    pub data: String,
145}
146
147/// Tool definition.
148#[derive(Debug, Clone, Serialize, Deserialize)]
149pub struct Tool {
150    /// Tool name.
151    pub name: String,
152    /// Tool description.
153    pub description: String,
154    /// Input schema.
155    pub input_schema: serde_json::Value,
156}
157
158/// Completion response.
159#[derive(Debug, Clone, Serialize, Deserialize)]
160pub struct CompletionResponse {
161    /// Response ID.
162    pub id: String,
163
164    /// Model used.
165    pub model: String,
166
167    /// Response content.
168    pub content: Vec<ContentBlock>,
169
170    /// Stop reason.
171    pub stop_reason: Option<StopReason>,
172
173    /// Token usage.
174    pub usage: TokenUsage,
175}
176
177/// Reason the generation stopped.
178#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
179#[serde(rename_all = "snake_case")]
180pub enum StopReason {
181    /// End of turn.
182    EndTurn,
183    /// Hit max tokens.
184    MaxTokens,
185    /// Hit stop sequence.
186    StopSequence,
187    /// Tool use requested.
188    ToolUse,
189}
190
191/// Streaming chunk.
192#[derive(Debug, Clone, Serialize, Deserialize)]
193pub struct StreamingChunk {
194    /// Chunk type.
195    pub chunk_type: ChunkType,
196    /// Delta content.
197    pub delta: Option<String>,
198    /// Content block index.
199    pub index: Option<usize>,
200}
201
202/// Type of streaming chunk.
203#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
204#[serde(rename_all = "snake_case")]
205pub enum ChunkType {
206    /// Message start.
207    MessageStart,
208    /// Content block start.
209    ContentBlockStart,
210    /// Content block delta.
211    ContentBlockDelta,
212    /// Content block stop.
213    ContentBlockStop,
214    /// Message delta.
215    MessageDelta,
216    /// Message stop.
217    MessageStop,
218}
219
220/// AI provider trait.
221#[async_trait]
222pub trait Provider: Send + Sync {
223    /// Provider name.
224    fn name(&self) -> &str;
225
226    /// List available models.
227    async fn list_models(&self) -> Result<Vec<String>, ProviderError>;
228
229    /// Create a completion.
230    async fn complete(
231        &self,
232        request: CompletionRequest,
233    ) -> Result<CompletionResponse, ProviderError>;
234
235    /// Create a streaming completion.
236    async fn complete_stream(
237        &self,
238        request: CompletionRequest,
239    ) -> Result<
240        Pin<Box<dyn futures::Stream<Item = Result<StreamingChunk, ProviderError>> + Send>>,
241        ProviderError,
242    >;
243}