Skip to main content

rain_engine_openai/
lib.rs

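//! OpenAI-compatible provider adapter for RainEngine.
//!
//! This baseline provider maps provider-neutral requests onto the OpenAI Chat
//! Completions API (`POST {base_url}/chat/completions`), exposing each available
//! skill as a `function` tool and turning the model's tool calls back into
//! planned skill calls.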
5
6use async_trait::async_trait;
7use rain_engine_core::{
8    AgentAction, LlmProvider, PlannedSkillCall, ProviderDecision, ProviderError, ProviderErrorKind,
9    ProviderRequest, ProviderRequestConfig,
10};
11use reqwest::{Client, StatusCode};
12use serde::{Deserialize, Serialize};
13use serde_json::{Value, json};
14use thiserror::Error;
15
16#[derive(Debug, Clone)]
17pub struct OpenAiCompatibleConfig {
18    pub base_url: String,
19    pub api_key: String,
20    pub default_request: ProviderRequestConfig,
21    pub system_prompt: String,
22}
23
24impl OpenAiCompatibleConfig {
25    pub fn validated(&self) -> Result<(), OpenAiConfigError> {
26        if self.base_url.trim().is_empty() {
27            return Err(OpenAiConfigError::Invalid(
28                "base_url must not be empty".to_string(),
29            ));
30        }
31        if self.api_key.trim().is_empty() {
32            return Err(OpenAiConfigError::Invalid(
33                "api_key must not be empty".to_string(),
34            ));
35        }
36        Ok(())
37    }
38}
39
/// Error returned when an [`OpenAiCompatibleConfig`] fails validation
/// (from `validated` or `OpenAiCompatibleProvider::new`).
#[derive(Debug, Error)]
pub enum OpenAiConfigError {
    /// The configuration is unusable. The payload is a human-readable reason and
    /// is also the `Display` output verbatim.
    #[error("{0}")]
    Invalid(String),
}
45
/// [`LlmProvider`] implementation that talks to an OpenAI-compatible
/// `/chat/completions` endpoint over HTTP.
#[derive(Clone)]
pub struct OpenAiCompatibleProvider {
    /// HTTP client used for every request (built with `Client::new()`).
    client: Client,
    /// Configuration supplied at construction; checked by `new` via `validated`.
    config: OpenAiCompatibleConfig,
}
51
52impl OpenAiCompatibleProvider {
53    pub fn new(config: OpenAiCompatibleConfig) -> Result<Self, OpenAiConfigError> {
54        config.validated()?;
55        Ok(Self {
56            client: Client::new(),
57            config,
58        })
59    }
60}
61
62#[async_trait]
63impl LlmProvider for OpenAiCompatibleProvider {
64    async fn generate_action(
65        &self,
66        input: ProviderRequest,
67    ) -> Result<ProviderDecision, ProviderError> {
68        let model = input
69            .config
70            .model
71            .clone()
72            .or_else(|| self.config.default_request.model.clone())
73            .ok_or_else(|| {
74                ProviderError::new(
75                    ProviderErrorKind::Configuration,
76                    "no model configured for OpenAI-compatible provider",
77                    false,
78                )
79            })?;
80
81        let request = ChatCompletionRequest {
82            model,
83            temperature: input
84                .config
85                .temperature
86                .or(self.config.default_request.temperature),
87            max_tokens: input
88                .config
89                .max_tokens
90                .or(self.config.default_request.max_tokens),
91            messages: map_to_chat_messages(&input, self.config.system_prompt.clone())?,
92            tools: input
93                .available_skills
94                .iter()
95                .map(|skill| ToolDefinition {
96                    kind: "function".to_string(),
97                    function: ToolFunction {
98                        name: skill.manifest.name.clone(),
99                        description: skill.manifest.description.clone(),
100                        parameters: skill.manifest.input_schema.clone(),
101                    },
102                })
103                .collect(),
104            tool_choice: Some(json!("auto")),
105        };
106
107        let response = self
108            .client
109            .post(format!(
110                "{}/chat/completions",
111                self.config.base_url.trim_end_matches('/')
112            ))
113            .bearer_auth(&self.config.api_key)
114            .json(&request)
115            .send()
116            .await
117            .map_err(|err| {
118                ProviderError::new(ProviderErrorKind::Transport, err.to_string(), true)
119            })?;
120
121        if !response.status().is_success() {
122            let status = response.status();
123            let body = response.text().await.unwrap_or_default();
124            return Err(classify_status(status, body));
125        }
126
127        let raw_text = response.text().await.map_err(|err| {
128            ProviderError::new(ProviderErrorKind::Transport, err.to_string(), true)
129        })?;
130
131        let body: ChatCompletionResponse = serde_json::from_str(&raw_text).map_err(|err| {
132            tracing::error!("OpenAI response deserialization failed: {err}\nRaw body: {raw_text}");
133            ProviderError::new(
134                ProviderErrorKind::InvalidResponse,
135                format!("error decoding response body: {err}"),
136                false,
137            )
138        })?;
139
140        let choice = body.choices.into_iter().next().ok_or_else(|| {
141            ProviderError::new(
142                ProviderErrorKind::InvalidResponse,
143                "provider returned no choices",
144                false,
145            )
146        })?;
147
148        if let Some(tool_calls) = choice.message.tool_calls
149            && !tool_calls.is_empty()
150        {
151            let mut planned = Vec::with_capacity(tool_calls.len());
152            for (index, tool_call) in tool_calls.into_iter().enumerate() {
153                let args = serde_json::from_str::<Value>(&tool_call.function.arguments).map_err(
154                    |err| {
155                        ProviderError::new(
156                            ProviderErrorKind::InvalidResponse,
157                            format!("invalid tool call arguments: {err}"),
158                            false,
159                        )
160                    },
161                )?;
162                planned.push(PlannedSkillCall {
163                    call_id: tool_call
164                        .id
165                        .unwrap_or_else(|| format!("openai-call-{index}")),
166                    name: tool_call.function.name,
167                    args,
168                    priority: 0,
169                    depends_on: Vec::new(),
170                    retry_policy: Default::default(),
171                    dry_run: false,
172                });
173            }
174            return Ok(ProviderDecision {
175                action: AgentAction::CallSkills(planned),
176                usage: None,
177                cache: None,
178            });
179        }
180
181        let content = choice.message.content.unwrap_or_default();
182        if let Ok(structured) = serde_json::from_str::<StructuredAction>(&content) {
183            return Ok(ProviderDecision {
184                action: match structured.kind.as_str() {
185                    "yield" => AgentAction::Yield {
186                        reason: structured.content,
187                    },
188                    _ => AgentAction::Respond {
189                        content: structured.content.unwrap_or_default(),
190                    },
191                },
192                usage: None,
193                cache: None,
194            });
195        }
196
197        Ok(ProviderDecision {
198            action: if content.trim().is_empty() {
199                AgentAction::Yield { reason: None }
200            } else {
201                AgentAction::Respond { content }
202            },
203            usage: None,
204            cache: None,
205        })
206    }
207}
208
209fn map_to_chat_messages(
210    input: &ProviderRequest,
211    system_prompt: String,
212) -> Result<Vec<ChatMessage>, ProviderError> {
213    let mut messages = vec![ChatMessage::system(system_prompt)];
214    for msg in &input.contents {
215        let role = match msg.role {
216            rain_engine_core::ProviderRole::System => "system",
217            rain_engine_core::ProviderRole::User => "user",
218            rain_engine_core::ProviderRole::Assistant => "assistant",
219            rain_engine_core::ProviderRole::Tool => "tool",
220        };
221
222        let mut content = String::new();
223        let mut tool_calls = None;
224        let mut tool_call_id = None;
225
226        for part in &msg.parts {
227            match part {
228                rain_engine_core::ProviderContentPart::Text(t) => {
229                    if !content.is_empty() {
230                        content.push('\n');
231                    }
232                    content.push_str(t);
233                }
234                rain_engine_core::ProviderContentPart::Json(j) => {
235                    // Try to parse as tool calls if it's an assistant message
236                    if msg.role == rain_engine_core::ProviderRole::Assistant {
237                        if let Ok(calls) =
238                            serde_json::from_value::<Vec<PlannedSkillCall>>(j.clone())
239                        {
240                            tool_calls = Some(
241                                calls
242                                    .into_iter()
243                                    .map(|c| ToolCallRequest {
244                                        id: c.call_id,
245                                        kind: "function".to_string(),
246                                        function: ToolFunctionCall {
247                                            name: c.name,
248                                            arguments: c.args.to_string(),
249                                        },
250                                    })
251                                    .collect(),
252                            );
253                        } else {
254                            if !content.is_empty() {
255                                content.push('\n');
256                            }
257                            content.push_str(&j.to_string());
258                        }
259                    } else {
260                        if !content.is_empty() {
261                            content.push('\n');
262                        }
263                        content.push_str(&j.to_string());
264                    }
265                }
266                rain_engine_core::ProviderContentPart::ToolResult(r) => {
267                    content.push_str(&serde_json::to_string(&r.output).unwrap_or_default());
268                    tool_call_id = Some(r.call_id.clone());
269                }
270                _ => {}
271            }
272        }
273
274        messages.push(ChatMessage {
275            role: role.to_string(),
276            content: if content.is_empty() && tool_calls.is_some() {
277                None
278            } else {
279                Some(content)
280            },
281            tool_calls,
282            tool_call_id,
283        });
284    }
285    Ok(messages)
286}
287
288fn classify_status(status: StatusCode, body: String) -> ProviderError {
289    match status {
290        StatusCode::TOO_MANY_REQUESTS => {
291            ProviderError::new(ProviderErrorKind::RateLimited, body, true)
292        }
293        StatusCode::BAD_REQUEST => {
294            ProviderError::new(ProviderErrorKind::InvalidResponse, body, false)
295        }
296        StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
297            ProviderError::new(ProviderErrorKind::Configuration, body, false)
298        }
299        _ if status.is_server_error() => {
300            ProviderError::new(ProviderErrorKind::Transport, body, true)
301        }
302        _ => ProviderError::new(ProviderErrorKind::Internal, body, false),
303    }
304}
305
306#[derive(Debug, Serialize)]
307struct ChatCompletionRequest {
308    model: String,
309    #[serde(skip_serializing_if = "Option::is_none")]
310    temperature: Option<f32>,
311    #[serde(skip_serializing_if = "Option::is_none")]
312    max_tokens: Option<u32>,
313    messages: Vec<ChatMessage>,
314    tools: Vec<ToolDefinition>,
315    #[serde(skip_serializing_if = "Option::is_none")]
316    tool_choice: Option<Value>,
317}
318
/// One entry of the `messages` array in a chat-completion request.
#[derive(Debug, Serialize)]
struct ChatMessage {
    /// One of `"system"`, `"user"`, `"assistant"` or `"tool"`.
    role: String,
    /// Message text; left out of the JSON when `None` (e.g. an assistant turn
    /// that only carries tool calls).
    #[serde(skip_serializing_if = "Option::is_none")]
    content: Option<String>,
    /// Tool calls replayed from an earlier assistant turn; left out when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_calls: Option<Vec<ToolCallRequest>>,
    /// Id of the call a tool-result message answers; left out when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    tool_call_id: Option<String>,
}
329
/// A tool call from an earlier assistant turn, replayed in the request history.
#[derive(Debug, Serialize)]
struct ToolCallRequest {
    /// Serialized as `type`; set to `"function"` by `map_to_chat_messages`.
    #[serde(rename = "type")]
    kind: String,
    /// Call id that later tool-result messages refer to via `tool_call_id`.
    id: String,
    function: ToolFunctionCall,
}
337
/// Function name and arguments of a replayed tool call.
#[derive(Debug, Serialize)]
struct ToolFunctionCall {
    name: String,
    /// JSON-encoded string (not a nested object) built from the planned call's
    /// `args`.
    arguments: String,
}
343
344impl ChatMessage {
345    fn system(content: String) -> Self {
346        Self {
347            role: "system".to_string(),
348            content: Some(content),
349            tool_calls: None,
350            tool_call_id: None,
351        }
352    }
353}
354
/// A tool offered to the model in the request's `tools` array.
#[derive(Debug, Serialize)]
struct ToolDefinition {
    /// Serialized as `type`; set to `"function"` when built from a skill.
    #[serde(rename = "type")]
    kind: String,
    function: ToolFunction,
}
361
/// Name, description and parameter schema of an offered tool, copied from a
/// skill's manifest.
#[derive(Debug, Serialize)]
struct ToolFunction {
    name: String,
    description: String,
    /// The skill's `input_schema` (presumably a JSON Schema object; confirm
    /// against the skill manifest contract).
    parameters: Value,
}
368
/// Subset of the chat-completion response body that this adapter reads.
/// Fields not declared here are ignored during deserialization.
#[derive(Debug, Deserialize)]
struct ChatCompletionResponse {
    choices: Vec<Choice>,
}
373
/// One completion choice; `generate_action` only uses the first one.
#[derive(Debug, Deserialize)]
struct Choice {
    message: ChoiceMessage,
}
378
/// Assistant message inside a choice; either field may be `null` or absent.
#[derive(Debug, Deserialize)]
struct ChoiceMessage {
    /// Text reply, possibly a JSON-encoded `StructuredAction`.
    content: Option<String>,
    /// Requested tool calls; when non-empty they take precedence over `content`.
    tool_calls: Option<Vec<ToolCall>>,
}
384
/// A tool call requested by the model.
#[derive(Debug, Deserialize)]
struct ToolCall {
    /// Provider-assigned call id; `generate_action` substitutes a synthetic id
    /// when it is missing.
    id: Option<String>,
    function: ToolCallFunction,
}
390
/// Function name and raw arguments of a requested tool call.
#[derive(Debug, Deserialize)]
struct ToolCallFunction {
    name: String,
    /// Arguments as a JSON-encoded string; decoded into a `Value` by
    /// `generate_action`.
    arguments: String,
}
396
/// Optional JSON envelope the model may return as its text content, e.g.
/// `{"type": "yield", "content": "..."}`. A `type` of `"yield"` yields; any
/// other value is treated as a response.
#[derive(Debug, Deserialize)]
struct StructuredAction {
    /// Serialized as `type`.
    #[serde(rename = "type")]
    kind: String,
    /// Yield reason or response text, depending on `kind`.
    content: Option<String>,
}
403
#[cfg(test)]
mod tests {
    // End-to-end tests: the provider is pointed at a local axum server that
    // answers `/chat/completions` with a canned JSON body.
    use super::*;
    use axum::{Json, Router, routing::post};
    use rain_engine_core::{
        AgentContextSnapshot, AgentId, AgentStateSnapshot, AgentTrigger, EnginePolicy,
        ProviderContentPart, SkillDefinition, SkillManifest,
    };
    use serde_json::json;

    /// Builds a minimal request: one user message ("hello"), one available skill
    /// named `echo`, and an explicit model / temperature / max_tokens.
    fn provider_request() -> ProviderRequest {
        // `generate_action` reads only `config`, `available_skills` and
        // `contents`; the remaining fields just satisfy the struct.
        ProviderRequest {
            trigger: AgentTrigger::Message {
                user_id: "u".to_string(),
                content: "hello".to_string(),
                attachments: Vec::new(),
            },
            context: AgentContextSnapshot {
                session_id: "s".to_string(),
                granted_scopes: vec!["tool:run".to_string()],
                trigger_id: "t".to_string(),
                idempotency_key: None,
                current_step: 0,
                max_steps: 8,
                history: Vec::new(),
                prior_tool_results: Vec::new(),
                session_cost_usd: 0.0,
                state: AgentStateSnapshot {
                    agent_id: AgentId("s".to_string()),
                    profile: None,
                    goals: Vec::new(),
                    tasks: Vec::new(),
                    observations: Vec::new(),
                    artifacts: Vec::new(),
                    resources: Vec::new(),
                    relationships: Vec::new(),
                    pending_wake: None,
                },
                policy: EnginePolicy::default(),
                active_execution_plan: None,
            },
            available_skills: vec![SkillDefinition {
                manifest: SkillManifest {
                    name: "echo".to_string(),
                    description: "Echo".to_string(),
                    input_schema: json!({"type":"object"}),
                    required_scopes: vec!["tool:run".to_string()],
                    capability_grants: vec![],
                    resource_policy: rain_engine_core::ResourcePolicy::default_for_tools(),
                    approval_required: false,
                    circuit_breaker_threshold: 0.5,
                },
                executor_kind: "wasm".to_string(),
            }],
            config: ProviderRequestConfig {
                model: Some("test-model".to_string()),
                temperature: Some(0.1),
                max_tokens: Some(32),
            },
            policy: EnginePolicy::default(),
            contents: vec![rain_engine_core::ProviderMessage {
                role: rain_engine_core::ProviderRole::User,
                parts: vec![ProviderContentPart::Text("hello".to_string())],
            }],
        }
    }

    /// Starts an axum server on an ephemeral localhost port that answers every
    /// `POST /chat/completions` with `response_body`, and returns its base URL.
    /// The server task is spawned and its handle is dropped (never joined).
    async fn spawn_test_server(response_body: Value) -> String {
        let app = Router::new().route(
            "/chat/completions",
            post(move || {
                // Clone per request so the handler can be invoked repeatedly.
                let response_body = response_body.clone();
                async move { Json(response_body) }
            }),
        );
        // Port 0 lets the OS pick a free port.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind");
        let addr = listener.local_addr().expect("addr");
        tokio::spawn(async move {
            axum::serve(listener, app).await.expect("server");
        });
        format!("http://{}", addr)
    }

    /// Two tool calls in one choice must come back as two planned skill calls,
    /// in order, with their ids and parsed arguments.
    #[tokio::test]
    async fn parses_parallel_tool_call_response() {
        let base_url = spawn_test_server(json!({
            "choices": [{
                "message": {
                    "content": null,
                    "tool_calls": [{
                        "id": "call-1",
                        "function": {
                            "name": "echo",
                            "arguments": "{\"value\":1}"
                        }
                    }, {
                        "id": "call-2",
                        "function": {
                            "name": "echo",
                            "arguments": "{\"value\":2}"
                        }
                    }]
                }
            }]
        }))
        .await;

        // The request's own config supplies the model; defaults are left empty.
        let provider = OpenAiCompatibleProvider::new(OpenAiCompatibleConfig {
            base_url,
            api_key: "token".to_string(),
            default_request: ProviderRequestConfig::default(),
            system_prompt: "You are helpful".to_string(),
        })
        .expect("provider");

        let decision = provider
            .generate_action(provider_request())
            .await
            .expect("decision");
        assert_eq!(
            decision.action,
            AgentAction::CallSkills(vec![
                PlannedSkillCall {
                    call_id: "call-1".to_string(),
                    name: "echo".to_string(),
                    args: json!({"value": 1}),
                    priority: 0,
                    depends_on: Vec::new(),
                    retry_policy: Default::default(),
                    dry_run: false,
                },
                PlannedSkillCall {
                    call_id: "call-2".to_string(),
                    name: "echo".to_string(),
                    args: json!({"value": 2}),
                    priority: 0,
                    depends_on: Vec::new(),
                    retry_policy: Default::default(),
                    dry_run: false,
                },
            ])
        );
    }

    /// Malformed tool-call arguments (`"{"`) must surface as an
    /// `InvalidResponse` error. The call in this fixture also has no `id`.
    #[tokio::test]
    async fn invalid_tool_call_arguments_are_classified() {
        let base_url = spawn_test_server(json!({
            "choices": [{
                "message": {
                    "content": null,
                    "tool_calls": [{
                        "function": {
                            "name": "echo",
                            "arguments": "{"
                        }
                    }]
                }
            }]
        }))
        .await;

        let provider = OpenAiCompatibleProvider::new(OpenAiCompatibleConfig {
            base_url,
            api_key: "token".to_string(),
            default_request: ProviderRequestConfig::default(),
            system_prompt: "You are helpful".to_string(),
        })
        .expect("provider");

        let error = provider
            .generate_action(provider_request())
            .await
            .expect_err("error");
        assert_eq!(error.kind, ProviderErrorKind::InvalidResponse);
    }
}