openai_agents_rust/model/
mod.rs1pub mod gpt_oss_responses;
2pub mod litellm;
3pub mod openai_chat;
4pub mod openai_realtime;
5
6use crate::error::AgentError;
7use async_trait::async_trait;
8use serde_json::Value;
9use std::collections::HashMap;
10use std::fmt::Debug;
11
12#[async_trait]
14pub trait Model: Send + Sync {
15 async fn generate(&self, prompt: &str) -> Result<String, AgentError>;
17
18 async fn get_response(
21 &self,
22 system_instructions: Option<&str>,
23 input: &str,
24 _model_settings: Option<HashMap<String, String>>,
25 _messages: Option<&[Value]>, _tools: Option<&[Value]>, _tool_choice: Option<Value>, _output_schema: Option<&str>,
29 _handoffs: Option<&[String]>,
30 _tracing_enabled: bool,
31 _previous_response_id: Option<&str>,
32 _prompt_config: Option<&str>,
33 ) -> Result<ModelResponse, AgentError> {
34 let text = if let Some(messages) = _messages {
36 let last_user = messages.iter().rev().find_map(|m| {
38 let role = m.get("role")?.as_str()?;
39 if role == "user" {
40 m.get("content")
41 .and_then(|c| c.as_str())
42 .map(|s| s.to_string())
43 } else {
44 None
45 }
46 });
47 match last_user {
48 Some(s) => self.generate(&s).await?,
49 None => {
50 let mut s = String::new();
51 if let Some(sys) = system_instructions {
52 s.push_str(sys);
53 s.push_str("\n\n");
54 }
55 s.push_str(input);
56 self.generate(&s).await?
57 }
58 }
59 } else {
60 let mut s = String::new();
61 if let Some(sys) = system_instructions {
62 s.push_str(sys);
63 s.push_str("\n\n");
64 }
65 s.push_str(input);
66 self.generate(&s).await?
67 };
68 Ok(ModelResponse {
69 id: None,
70 text: Some(text),
71 tool_calls: vec![],
72 })
73 }
74}
75
76#[derive(Debug, Clone)]
78pub struct ModelResponse {
79 pub id: Option<String>,
80 pub text: Option<String>,
81 pub tool_calls: Vec<ToolCall>,
82}
83
/// A request from the model to invoke a named tool.
#[derive(Clone, Debug)]
pub struct ToolCall {
    /// Optional identifier of this item.
    /// NOTE(review): how this differs from `call_id` is backend-specific — confirm.
    pub id: Option<String>,
    /// Name of the tool to invoke.
    pub name: String,
    /// Tool arguments as a raw string (presumably JSON-encoded — verify per backend).
    pub arguments: String,
    /// Optional identifier used to correlate the call with its result.
    pub call_id: Option<String>,
}
91
92#[derive(Debug, Clone)]
94pub enum ModelStreamEvent {
95 TextDelta(String),
96 ToolCallDelta(ToolCall),
97}