1use async_trait::async_trait;
7use rain_engine_core::{
8 AgentAction, LlmProvider, PlannedSkillCall, ProviderDecision, ProviderError, ProviderErrorKind,
9 ProviderRequest, ProviderRequestConfig,
10};
11use reqwest::{Client, StatusCode};
12use serde::{Deserialize, Serialize};
13use serde_json::{Value, json};
14use thiserror::Error;
15
16#[derive(Debug, Clone)]
17pub struct OpenAiCompatibleConfig {
18 pub base_url: String,
19 pub api_key: String,
20 pub default_request: ProviderRequestConfig,
21 pub system_prompt: String,
22}
23
24impl OpenAiCompatibleConfig {
25 pub fn validated(&self) -> Result<(), OpenAiConfigError> {
26 if self.base_url.trim().is_empty() {
27 return Err(OpenAiConfigError::Invalid(
28 "base_url must not be empty".to_string(),
29 ));
30 }
31 if self.api_key.trim().is_empty() {
32 return Err(OpenAiConfigError::Invalid(
33 "api_key must not be empty".to_string(),
34 ));
35 }
36 Ok(())
37 }
38}
39
40#[derive(Debug, Error)]
41pub enum OpenAiConfigError {
42 #[error("{0}")]
43 Invalid(String),
44}
45
46#[derive(Clone)]
47pub struct OpenAiCompatibleProvider {
48 client: Client,
49 config: OpenAiCompatibleConfig,
50}
51
52impl OpenAiCompatibleProvider {
53 pub fn new(config: OpenAiCompatibleConfig) -> Result<Self, OpenAiConfigError> {
54 config.validated()?;
55 Ok(Self {
56 client: Client::new(),
57 config,
58 })
59 }
60}
61
62#[async_trait]
63impl LlmProvider for OpenAiCompatibleProvider {
64 async fn generate_action(
65 &self,
66 input: ProviderRequest,
67 ) -> Result<ProviderDecision, ProviderError> {
68 let model = input
69 .config
70 .model
71 .clone()
72 .or_else(|| self.config.default_request.model.clone())
73 .ok_or_else(|| {
74 ProviderError::new(
75 ProviderErrorKind::Configuration,
76 "no model configured for OpenAI-compatible provider",
77 false,
78 )
79 })?;
80
81 let request = ChatCompletionRequest {
82 model,
83 temperature: input
84 .config
85 .temperature
86 .or(self.config.default_request.temperature),
87 max_tokens: input
88 .config
89 .max_tokens
90 .or(self.config.default_request.max_tokens),
91 messages: map_to_chat_messages(&input, self.config.system_prompt.clone())?,
92 tools: input
93 .available_skills
94 .iter()
95 .map(|skill| ToolDefinition {
96 kind: "function".to_string(),
97 function: ToolFunction {
98 name: skill.manifest.name.clone(),
99 description: skill.manifest.description.clone(),
100 parameters: skill.manifest.input_schema.clone(),
101 },
102 })
103 .collect(),
104 tool_choice: Some(json!("auto")),
105 };
106
107 let response = self
108 .client
109 .post(format!(
110 "{}/chat/completions",
111 self.config.base_url.trim_end_matches('/')
112 ))
113 .bearer_auth(&self.config.api_key)
114 .json(&request)
115 .send()
116 .await
117 .map_err(|err| {
118 ProviderError::new(ProviderErrorKind::Transport, err.to_string(), true)
119 })?;
120
121 if !response.status().is_success() {
122 let status = response.status();
123 let body = response.text().await.unwrap_or_default();
124 return Err(classify_status(status, body));
125 }
126
127 let raw_text = response.text().await.map_err(|err| {
128 ProviderError::new(ProviderErrorKind::Transport, err.to_string(), true)
129 })?;
130
131 let body: ChatCompletionResponse = serde_json::from_str(&raw_text).map_err(|err| {
132 tracing::error!("OpenAI response deserialization failed: {err}\nRaw body: {raw_text}");
133 ProviderError::new(
134 ProviderErrorKind::InvalidResponse,
135 format!("error decoding response body: {err}"),
136 false,
137 )
138 })?;
139
140 let choice = body.choices.into_iter().next().ok_or_else(|| {
141 ProviderError::new(
142 ProviderErrorKind::InvalidResponse,
143 "provider returned no choices",
144 false,
145 )
146 })?;
147
148 if let Some(tool_calls) = choice.message.tool_calls
149 && !tool_calls.is_empty()
150 {
151 let mut planned = Vec::with_capacity(tool_calls.len());
152 for (index, tool_call) in tool_calls.into_iter().enumerate() {
153 let args = serde_json::from_str::<Value>(&tool_call.function.arguments).map_err(
154 |err| {
155 ProviderError::new(
156 ProviderErrorKind::InvalidResponse,
157 format!("invalid tool call arguments: {err}"),
158 false,
159 )
160 },
161 )?;
162 planned.push(PlannedSkillCall {
163 call_id: tool_call
164 .id
165 .unwrap_or_else(|| format!("openai-call-{index}")),
166 name: tool_call.function.name,
167 args,
168 priority: 0,
169 depends_on: Vec::new(),
170 retry_policy: Default::default(),
171 dry_run: false,
172 });
173 }
174 return Ok(ProviderDecision {
175 action: AgentAction::CallSkills(planned),
176 usage: None,
177 cache: None,
178 });
179 }
180
181 let content = choice.message.content.unwrap_or_default();
182 if let Ok(structured) = serde_json::from_str::<StructuredAction>(&content) {
183 return Ok(ProviderDecision {
184 action: match structured.kind.as_str() {
185 "yield" => AgentAction::Yield {
186 reason: structured.content,
187 },
188 _ => AgentAction::Respond {
189 content: structured.content.unwrap_or_default(),
190 },
191 },
192 usage: None,
193 cache: None,
194 });
195 }
196
197 Ok(ProviderDecision {
198 action: if content.trim().is_empty() {
199 AgentAction::Yield { reason: None }
200 } else {
201 AgentAction::Respond { content }
202 },
203 usage: None,
204 cache: None,
205 })
206 }
207}
208
209fn map_to_chat_messages(
210 input: &ProviderRequest,
211 system_prompt: String,
212) -> Result<Vec<ChatMessage>, ProviderError> {
213 let mut messages = vec![ChatMessage::system(system_prompt)];
214 for msg in &input.contents {
215 let role = match msg.role {
216 rain_engine_core::ProviderRole::System => "system",
217 rain_engine_core::ProviderRole::User => "user",
218 rain_engine_core::ProviderRole::Assistant => "assistant",
219 rain_engine_core::ProviderRole::Tool => "tool",
220 };
221
222 let mut content = String::new();
223 let mut tool_calls = None;
224 let mut tool_call_id = None;
225
226 for part in &msg.parts {
227 match part {
228 rain_engine_core::ProviderContentPart::Text(t) => {
229 if !content.is_empty() {
230 content.push('\n');
231 }
232 content.push_str(t);
233 }
234 rain_engine_core::ProviderContentPart::Json(j) => {
235 if msg.role == rain_engine_core::ProviderRole::Assistant {
237 if let Ok(calls) =
238 serde_json::from_value::<Vec<PlannedSkillCall>>(j.clone())
239 {
240 tool_calls = Some(
241 calls
242 .into_iter()
243 .map(|c| ToolCallRequest {
244 id: c.call_id,
245 kind: "function".to_string(),
246 function: ToolFunctionCall {
247 name: c.name,
248 arguments: c.args.to_string(),
249 },
250 })
251 .collect(),
252 );
253 } else {
254 if !content.is_empty() {
255 content.push('\n');
256 }
257 content.push_str(&j.to_string());
258 }
259 } else {
260 if !content.is_empty() {
261 content.push('\n');
262 }
263 content.push_str(&j.to_string());
264 }
265 }
266 rain_engine_core::ProviderContentPart::ToolResult(r) => {
267 content.push_str(&serde_json::to_string(&r.output).unwrap_or_default());
268 tool_call_id = Some(r.call_id.clone());
269 }
270 _ => {}
271 }
272 }
273
274 messages.push(ChatMessage {
275 role: role.to_string(),
276 content: if content.is_empty() && tool_calls.is_some() {
277 None
278 } else {
279 Some(content)
280 },
281 tool_calls,
282 tool_call_id,
283 });
284 }
285 Ok(messages)
286}
287
288fn classify_status(status: StatusCode, body: String) -> ProviderError {
289 match status {
290 StatusCode::TOO_MANY_REQUESTS => {
291 ProviderError::new(ProviderErrorKind::RateLimited, body, true)
292 }
293 StatusCode::BAD_REQUEST => {
294 ProviderError::new(ProviderErrorKind::InvalidResponse, body, false)
295 }
296 StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => {
297 ProviderError::new(ProviderErrorKind::Configuration, body, false)
298 }
299 _ if status.is_server_error() => {
300 ProviderError::new(ProviderErrorKind::Transport, body, true)
301 }
302 _ => ProviderError::new(ProviderErrorKind::Internal, body, false),
303 }
304}
305
306#[derive(Debug, Serialize)]
307struct ChatCompletionRequest {
308 model: String,
309 #[serde(skip_serializing_if = "Option::is_none")]
310 temperature: Option<f32>,
311 #[serde(skip_serializing_if = "Option::is_none")]
312 max_tokens: Option<u32>,
313 messages: Vec<ChatMessage>,
314 tools: Vec<ToolDefinition>,
315 #[serde(skip_serializing_if = "Option::is_none")]
316 tool_choice: Option<Value>,
317}
318
319#[derive(Debug, Serialize)]
320struct ChatMessage {
321 role: String,
322 #[serde(skip_serializing_if = "Option::is_none")]
323 content: Option<String>,
324 #[serde(skip_serializing_if = "Option::is_none")]
325 tool_calls: Option<Vec<ToolCallRequest>>,
326 #[serde(skip_serializing_if = "Option::is_none")]
327 tool_call_id: Option<String>,
328}
329
330#[derive(Debug, Serialize)]
331struct ToolCallRequest {
332 #[serde(rename = "type")]
333 kind: String,
334 id: String,
335 function: ToolFunctionCall,
336}
337
338#[derive(Debug, Serialize)]
339struct ToolFunctionCall {
340 name: String,
341 arguments: String,
342}
343
344impl ChatMessage {
345 fn system(content: String) -> Self {
346 Self {
347 role: "system".to_string(),
348 content: Some(content),
349 tool_calls: None,
350 tool_call_id: None,
351 }
352 }
353}
354
355#[derive(Debug, Serialize)]
356struct ToolDefinition {
357 #[serde(rename = "type")]
358 kind: String,
359 function: ToolFunction,
360}
361
362#[derive(Debug, Serialize)]
363struct ToolFunction {
364 name: String,
365 description: String,
366 parameters: Value,
367}
368
369#[derive(Debug, Deserialize)]
370struct ChatCompletionResponse {
371 choices: Vec<Choice>,
372}
373
374#[derive(Debug, Deserialize)]
375struct Choice {
376 message: ChoiceMessage,
377}
378
379#[derive(Debug, Deserialize)]
380struct ChoiceMessage {
381 content: Option<String>,
382 tool_calls: Option<Vec<ToolCall>>,
383}
384
385#[derive(Debug, Deserialize)]
386struct ToolCall {
387 id: Option<String>,
388 function: ToolCallFunction,
389}
390
391#[derive(Debug, Deserialize)]
392struct ToolCallFunction {
393 name: String,
394 arguments: String,
395}
396
397#[derive(Debug, Deserialize)]
398struct StructuredAction {
399 #[serde(rename = "type")]
400 kind: String,
401 content: Option<String>,
402}
403
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{Json, Router, routing::post};
    use rain_engine_core::{
        AgentContextSnapshot, AgentId, AgentStateSnapshot, AgentTrigger, EnginePolicy,
        ProviderContentPart, SkillDefinition, SkillManifest,
    };
    use serde_json::json;

    /// Agent state with no profile, goals, tasks or any other content.
    fn empty_state() -> AgentStateSnapshot {
        AgentStateSnapshot {
            agent_id: AgentId("s".to_string()),
            profile: None,
            goals: Vec::new(),
            tasks: Vec::new(),
            observations: Vec::new(),
            artifacts: Vec::new(),
            resources: Vec::new(),
            relationships: Vec::new(),
            pending_wake: None,
        }
    }

    /// The single `echo` skill advertised by the test request.
    fn echo_skill() -> SkillDefinition {
        SkillDefinition {
            manifest: SkillManifest {
                name: "echo".to_string(),
                description: "Echo".to_string(),
                input_schema: json!({"type":"object"}),
                required_scopes: vec!["tool:run".to_string()],
                capability_grants: vec![],
                resource_policy: rain_engine_core::ResourcePolicy::default_for_tools(),
                approval_required: false,
                circuit_breaker_threshold: 0.5,
            },
            executor_kind: "wasm".to_string(),
        }
    }

    /// Minimal request: one user message saying "hello" and the `echo` skill.
    fn provider_request() -> ProviderRequest {
        let trigger = AgentTrigger::Message {
            user_id: "u".to_string(),
            content: "hello".to_string(),
            attachments: Vec::new(),
        };
        let context = AgentContextSnapshot {
            session_id: "s".to_string(),
            granted_scopes: vec!["tool:run".to_string()],
            trigger_id: "t".to_string(),
            idempotency_key: None,
            current_step: 0,
            max_steps: 8,
            history: Vec::new(),
            prior_tool_results: Vec::new(),
            session_cost_usd: 0.0,
            state: empty_state(),
            policy: EnginePolicy::default(),
            active_execution_plan: None,
        };
        let config = ProviderRequestConfig {
            model: Some("test-model".to_string()),
            temperature: Some(0.1),
            max_tokens: Some(32),
        };
        let user_message = rain_engine_core::ProviderMessage {
            role: rain_engine_core::ProviderRole::User,
            parts: vec![ProviderContentPart::Text("hello".to_string())],
        };

        ProviderRequest {
            trigger,
            context,
            available_skills: vec![echo_skill()],
            config,
            policy: EnginePolicy::default(),
            contents: vec![user_message],
        }
    }

    /// Starts a local HTTP server that answers every `POST /chat/completions`
    /// with `response_body` and returns its base URL.
    async fn spawn_test_server(response_body: Value) -> String {
        let handler = move || {
            let response_body = response_body.clone();
            async move { Json(response_body) }
        };
        let app = Router::new().route("/chat/completions", post(handler));

        // Port 0 lets the OS pick a free port.
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0")
            .await
            .expect("bind");
        let addr = listener.local_addr().expect("addr");
        tokio::spawn(async move {
            axum::serve(listener, app).await.expect("server");
        });
        format!("http://{}", addr)
    }

    /// Provider configured against the given test server.
    fn provider_for(base_url: String) -> OpenAiCompatibleProvider {
        OpenAiCompatibleProvider::new(OpenAiCompatibleConfig {
            base_url,
            api_key: "token".to_string(),
            default_request: ProviderRequestConfig::default(),
            system_prompt: "You are helpful".to_string(),
        })
        .expect("provider")
    }

    /// The planned call expected for an `echo` tool call with the given id
    /// and `value` argument.
    fn expected_echo_call(call_id: &str, value: i64) -> PlannedSkillCall {
        PlannedSkillCall {
            call_id: call_id.to_string(),
            name: "echo".to_string(),
            args: json!({"value": value}),
            priority: 0,
            depends_on: Vec::new(),
            retry_policy: Default::default(),
            dry_run: false,
        }
    }

    #[tokio::test]
    async fn parses_parallel_tool_call_response() {
        let completion = json!({
            "choices": [{
                "message": {
                    "content": null,
                    "tool_calls": [{
                        "id": "call-1",
                        "function": {
                            "name": "echo",
                            "arguments": "{\"value\":1}"
                        }
                    }, {
                        "id": "call-2",
                        "function": {
                            "name": "echo",
                            "arguments": "{\"value\":2}"
                        }
                    }]
                }
            }]
        });
        let provider = provider_for(spawn_test_server(completion).await);

        let decision = provider
            .generate_action(provider_request())
            .await
            .expect("decision");

        assert_eq!(
            decision.action,
            AgentAction::CallSkills(vec![
                expected_echo_call("call-1", 1),
                expected_echo_call("call-2", 2),
            ])
        );
    }

    #[tokio::test]
    async fn invalid_tool_call_arguments_are_classified() {
        // The argument string is truncated JSON and must be rejected.
        let completion = json!({
            "choices": [{
                "message": {
                    "content": null,
                    "tool_calls": [{
                        "function": {
                            "name": "echo",
                            "arguments": "{"
                        }
                    }]
                }
            }]
        });
        let provider = provider_for(spawn_test_server(completion).await);

        let error = provider
            .generate_action(provider_request())
            .await
            .expect_err("error");

        assert_eq!(error.kind, ProviderErrorKind::InvalidResponse);
    }
}