use async_stream::stream;
use futures_util::StreamExt;
use reqwest::{Client, blocking::Client as BlockingClient};
use serde_json::{Value, json};
use std::collections::{BTreeMap, BTreeSet};
use std::sync::{OnceLock, RwLock};
use crate::{
EmbeddingModel, EmbeddingResult, Error, FinishReason, LanguageModel, ModelMessage,
ModelRequest, ModelResponse, Part, ProviderRegistration, Result, Role, ToolChoice,
ToolDefinition, ToolSchema, Usage,
};
/// A chat-completions language model bound to one OpenAI-compatible
/// provider and one model id.
#[derive(Clone, Debug)]
pub struct OpenAiLanguageModel {
    // Endpoint/auth configuration used for every request.
    provider: OpenAiProvider,
    // Sent as the "model" field of each request body.
    model_id: String,
}
/// An embeddings model bound to one OpenAI-compatible provider and one
/// model id. Embedding requests are performed synchronously.
#[derive(Clone, Debug)]
pub struct OpenAiEmbeddingModel {
    // Endpoint/auth configuration used for every request.
    provider: OpenAiProvider,
    // Sent as the "model" field of each request body.
    model_id: String,
}
/// Configuration for an OpenAI-compatible API endpoint.
///
/// Construct with [`OpenAiProvider::new`], customize via the builder
/// methods, and install at runtime with [`register`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub struct OpenAiProvider {
    /// Unique provider id (e.g. "openai") used to resolve models.
    pub id: String,
    /// Human-readable display name.
    pub name: String,
    /// API root; the builder strips trailing slashes so path joins yield
    /// exactly one separator.
    pub base_url: String,
    /// Explicit API key; when `None`, `api_key_env` is consulted instead.
    pub api_key: Option<String>,
    /// Environment variable holding the API key (default "OPENAI_API_KEY").
    pub api_key_env: &'static str,
}
/// Registry entry pairing a provider with a leaked `'static` copy of its
/// id, matching the `&'static str` ids used by inventory registrations.
#[derive(Clone, Debug)]
struct RegisteredOpenAiProvider {
    // Leaked once per unique id at registration time; see `register`.
    id: &'static str,
    provider: OpenAiProvider,
}
impl OpenAiProvider {
pub fn new<I, N>(id: I, name: N) -> Self
where
I: Into<String>,
N: Into<String>,
{
Self {
id: id.into(),
name: name.into(),
base_url: "https://api.openai.com/v1".to_string(),
api_key: None,
api_key_env: "OPENAI_API_KEY",
}
}
pub fn base_url(mut self, base_url: impl Into<String>) -> Self {
self.base_url = base_url.into().trim_end_matches('/').to_string();
self
}
pub fn api_key(mut self, api_key: impl Into<String>) -> Self {
self.api_key = Some(api_key.into());
self
}
pub fn api_key_env(mut self, api_key_env: &'static str) -> Self {
self.api_key_env = api_key_env;
self
}
}
/// The built-in hosted-OpenAI provider, always available without
/// runtime registration.
fn default_openai_provider() -> OpenAiProvider {
    let provider = OpenAiProvider::new("openai", "OpenAI");
    provider
}
/// Process-wide registry of runtime-registered OpenAI-compatible
/// providers, keyed by provider id. Lazily initialized on first access.
fn registered_openai_providers() -> &'static RwLock<BTreeMap<String, RegisteredOpenAiProvider>> {
    static PROVIDERS: OnceLock<RwLock<BTreeMap<String, RegisteredOpenAiProvider>>> =
        OnceLock::new();
    PROVIDERS.get_or_init(|| RwLock::new(BTreeMap::new()))
}
/// Registers a custom OpenAI-compatible provider under its id.
///
/// # Errors
/// Returns [`Error::DuplicateProvider`] when the id collides with either a
/// compile-time `inventory` registration or a previously registered custom
/// provider.
pub fn register(provider: OpenAiProvider) -> Result<()> {
    let taken_at_compile_time = inventory::iter::<ProviderRegistration>
        .into_iter()
        .any(|registered| registered.id == provider.id);
    if taken_at_compile_time {
        return Err(Error::DuplicateProvider(provider.id));
    }
    let mut providers = registered_openai_providers()
        .write()
        .expect("custom OpenAI provider registry poisoned");
    if providers.contains_key(&provider.id) {
        return Err(Error::DuplicateProvider(provider.id));
    }
    // Leak one copy of the id so runtime registrations can hand out
    // `&'static str` ids like inventory-based providers do; this happens at
    // most once per unique id because duplicates are rejected above while
    // the write lock is held.
    let id: &'static str = Box::leak(provider.id.clone().into_boxed_str());
    let key = provider.id.clone();
    providers.insert(key, RegisteredOpenAiProvider { id, provider });
    Ok(())
}
/// Returns every known OpenAI-compatible provider, sorted by id.
///
/// The built-in "openai" provider is always included alongside any custom
/// registrations.
pub fn registered() -> Vec<OpenAiProvider> {
    // Snapshot the custom providers first so the read lock is released
    // before sorting.
    let custom: Vec<OpenAiProvider> = registered_openai_providers()
        .read()
        .expect("custom OpenAI provider registry poisoned")
        .values()
        .map(|registered| registered.provider.clone())
        .collect();
    let mut providers = Vec::with_capacity(custom.len() + 1);
    providers.push(default_openai_provider());
    providers.extend(custom);
    providers.sort_by(|left, right| left.id.cmp(&right.id));
    providers
}
/// Ids of all custom (runtime-registered) providers; built-ins such as
/// "openai" are not included.
pub(crate) fn registered_provider_ids() -> Vec<&'static str> {
    let providers = registered_openai_providers()
        .read()
        .expect("custom OpenAI provider registry poisoned");
    providers.values().map(|registered| registered.id).collect()
}
/// Looks up a custom provider by id and builds a language model for it.
///
/// Returns `None` when no custom provider with that id exists; otherwise
/// the inner `Result` reports construction errors (e.g. an empty model id).
pub(crate) fn resolve_language_model(
    provider_id: &str,
    model_id: &str,
) -> Option<Result<Box<dyn LanguageModel>>> {
    let registry = registered_openai_providers()
        .read()
        .expect("custom OpenAI provider registry poisoned");
    let provider = registry.get(provider_id)?.provider.clone();
    drop(registry);
    Some(openai_language_model_with_provider(provider, model_id))
}
/// Looks up a custom provider by id and builds an embedding model for it.
///
/// Returns `None` when no custom provider with that id exists; otherwise
/// the inner `Result` reports construction errors (e.g. an empty model id).
pub(crate) fn resolve_embedding_model(
    provider_id: &str,
    model_id: &str,
) -> Option<Result<Box<dyn EmbeddingModel>>> {
    let registry = registered_openai_providers()
        .read()
        .expect("custom OpenAI provider registry poisoned");
    let provider = registry.get(provider_id)?.provider.clone();
    drop(registry);
    Some(openai_embedding_model_with_provider(provider, model_id))
}
impl LanguageModel for OpenAiLanguageModel {
    /// Returns the model identifier this instance was constructed with.
    fn model_id(&self) -> &str {
        &self.model_id
    }
    /// One-shot chat completion: POSTs to `chat/completions` and converts
    /// the reply into a [`ModelResponse`].
    ///
    /// # Errors
    /// `Error::Api` with the server's message on non-2xx statuses, plus any
    /// transport/JSON errors from the underlying helpers.
    fn generate<'a>(&'a self, request: &'a ModelRequest) -> crate::ModelFuture<'a, ModelResponse> {
        Box::pin(async move {
            let (status, body) = openai_post_json(
                &self.provider,
                "chat/completions",
                openai_chat_request(&self.model_id, request),
            )
            .await?;
            if !(200..300).contains(&status) {
                return Err(Error::Api(openai_error_message(&body)));
            }
            openai_chat_response_to_model_response(&self.provider.id, &self.model_id, &body)
        })
    }
    /// Streaming chat completion: POSTs with `stream: true` and adapts the
    /// SSE response into the crate's event stream.
    fn stream<'a>(
        &'a self,
        request: &'a ModelRequest,
    ) -> crate::ModelFuture<'a, crate::StreamTextStream> {
        Box::pin(async move {
            let response = openai_post_stream(
                &self.provider,
                "chat/completions",
                openai_stream_request(&self.model_id, request),
            )
            .await?;
            openai_stream_response_to_events(response).await
        })
    }
}
impl EmbeddingModel for OpenAiEmbeddingModel {
    /// Returns the model identifier this instance was constructed with.
    fn model_id(&self) -> &str {
        &self.model_id
    }
    /// Embeds a single string synchronously via the provider's `embeddings`
    /// endpoint (the HTTP call runs on a dedicated thread; see
    /// `openai_post_json_blocking`). Only the first item of the response's
    /// "data" array is used, since exactly one input is sent.
    ///
    /// # Errors
    /// `Error::Api` on non-2xx statuses; `Error::Parse` when the embedding
    /// payload is missing or contains non-numeric components.
    fn embed(&self, value: &str) -> Result<EmbeddingResult> {
        let (status, body) = openai_post_json_blocking(
            &self.provider,
            "embeddings",
            json!({
                "model": self.model_id,
                "input": value,
            }),
        )?;
        if !(200..300).contains(&status) {
            return Err(Error::Api(openai_error_message(&body)));
        }
        // Walk data[0].embedding and narrow each component from the JSON
        // f64 representation to f32.
        let embedding = body
            .get("data")
            .and_then(Value::as_array)
            .and_then(|items| items.first())
            .and_then(|item| item.get("embedding"))
            .and_then(Value::as_array)
            .ok_or_else(|| Error::Parse("missing embedding data".to_string()))?
            .iter()
            .map(|value| {
                value
                    .as_f64()
                    .map(|value| value as f32)
                    .ok_or_else(|| Error::Parse("invalid embedding value".to_string()))
            })
            .collect::<Result<Vec<_>>>()?;
        Ok(EmbeddingResult {
            embedding,
            usage: openai_usage(&body),
        })
    }
}
/// Inventory entry point: builds a language model against the built-in
/// hosted-OpenAI provider.
fn openai_language_model(model_id: &str) -> Result<Box<dyn LanguageModel>> {
    let provider = default_openai_provider();
    openai_language_model_with_provider(provider, model_id)
}
/// Builds a boxed language model for `provider`.
///
/// # Errors
/// Rejects an empty model id with [`Error::UnsupportedModel`], surfacing
/// the would-be model path (e.g. "openai/") in the error.
fn openai_language_model_with_provider(
    provider: OpenAiProvider,
    model_id: &str,
) -> Result<Box<dyn LanguageModel>> {
    if model_id.is_empty() {
        return Err(Error::UnsupportedModel(format!("{}/", provider.id)));
    }
    let model = OpenAiLanguageModel {
        provider,
        model_id: model_id.to_owned(),
    };
    Ok(Box::new(model))
}
/// Inventory entry point: builds an embedding model against the built-in
/// hosted-OpenAI provider.
fn openai_embedding_model(model_id: &str) -> Result<Box<dyn EmbeddingModel>> {
    let provider = default_openai_provider();
    openai_embedding_model_with_provider(provider, model_id)
}
/// Builds a boxed embedding model for `provider`.
///
/// # Errors
/// Rejects an empty model id with [`Error::UnsupportedModel`], surfacing
/// the would-be model path (e.g. "openai/") in the error.
fn openai_embedding_model_with_provider(
    provider: OpenAiProvider,
    model_id: &str,
) -> Result<Box<dyn EmbeddingModel>> {
    if model_id.is_empty() {
        return Err(Error::UnsupportedModel(format!("{}/", provider.id)));
    }
    let model = OpenAiEmbeddingModel {
        provider,
        model_id: model_id.to_owned(),
    };
    Ok(Box::new(model))
}
// Compile-time registration of the built-in "openai" provider so core
// model resolution can discover it without any runtime setup.
inventory::submit! {
    ProviderRegistration {
        id: "openai",
        language_model: openai_language_model,
        embedding_model: openai_embedding_model,
    }
}
fn openai_api_key(provider: &OpenAiProvider) -> Result<String> {
if let Some(api_key) = &provider.api_key {
return Ok(api_key.clone());
}
std::env::var(provider.api_key_env)
.map_err(|_| Error::MissingEnvironmentVariable(provider.api_key_env))
}
/// Joins `path` onto the provider's base URL with exactly one `/`.
///
/// `base_url` never ends with a slash (the builder strips trailing slashes
/// and the default has none); trimming any leading slash from `path` keeps
/// a caller-supplied "/embeddings"-style path from producing a double
/// slash. Existing callers pass bare paths, so behavior is unchanged for
/// them.
fn openai_url(provider: &OpenAiProvider, path: &str) -> String {
    format!("{}/{}", provider.base_url, path.trim_start_matches('/'))
}
/// Blocking variant of [`openai_post_json`] used by the synchronous
/// embeddings API; returns the HTTP status and the parsed JSON body.
///
/// NOTE(review): the dedicated worker thread looks like a guard against
/// using reqwest's blocking client from inside an async runtime (where it
/// panics) — confirm against the callers before simplifying.
fn openai_post_json_blocking(
    provider: &OpenAiProvider,
    path: &str,
    body: Value,
) -> Result<(u16, Value)> {
    // Resolve credentials and the URL up front so those errors are reported
    // directly instead of from inside the worker thread.
    let api_key = openai_api_key(provider)?;
    let url = openai_url(provider, path);
    std::thread::spawn(move || {
        let response = BlockingClient::builder()
            .build()
            .map_err(|error| Error::Http(error.to_string()))?
            .post(url)
            .bearer_auth(api_key)
            .json(&body)
            .send()
            .map_err(|error| Error::Http(error.to_string()))?;
        let status = response.status().as_u16();
        let body = response
            .json()
            .map_err(|error| Error::Json(error.to_string()))?;
        Ok((status, body))
    })
    // join() errs only when the worker thread panicked; surface that as an
    // HTTP-level failure rather than propagating the panic.
    .join()
    .map_err(|_| Error::Http("openai request thread panicked".to_string()))?
}
/// POSTs `body` as JSON to the provider endpoint and returns the HTTP
/// status together with the parsed JSON response body.
///
/// NOTE(review): a fresh `Client` is built per request, which forgoes
/// connection pooling; consider caching one if this path becomes hot.
async fn openai_post_json(
    provider: &OpenAiProvider,
    path: &str,
    body: Value,
) -> Result<(u16, Value)> {
    let api_key = openai_api_key(provider)?;
    let url = openai_url(provider, path);
    let client = Client::builder()
        .build()
        .map_err(|error| Error::Http(error.to_string()))?;
    let response = client
        .post(url)
        .bearer_auth(api_key)
        .json(&body)
        .send()
        .await
        .map_err(|error| Error::Http(error.to_string()))?;
    let status = response.status().as_u16();
    let parsed = response
        .json::<Value>()
        .await
        .map_err(|error| Error::Json(error.to_string()))?;
    Ok((status, parsed))
}
async fn openai_post_stream(
provider: &OpenAiProvider,
path: &str,
body: Value,
) -> Result<reqwest::Response> {
let api_key = openai_api_key(provider)?;
let url = openai_url(provider, path);
let response = Client::builder()
.build()
.map_err(|error| Error::Http(error.to_string()))?
.post(url)
.bearer_auth(api_key)
.json(&body)
.send()
.await
.map_err(|error| Error::Http(error.to_string()))?;
if response.status().is_success() {
return Ok(response);
}
let body = response
.json::<Value>()
.await
.map_err(|error| Error::Json(error.to_string()))?;
Err(Error::Api(openai_error_message(&body)))
}
/// Builds the JSON body for a (non-streaming) chat-completion request.
///
/// Optional sampling settings and tool definitions are only serialized
/// when present, keeping the payload minimal for compatible servers.
fn openai_chat_request(model_id: &str, request: &ModelRequest) -> Value {
    let mut body = json!({
        "model": model_id,
        "messages": openai_messages(&request.messages),
    });
    let settings = &request.settings;
    if let Some(temperature) = settings.temperature {
        body["temperature"] = json!(temperature);
    }
    if let Some(max_tokens) = settings.max_output_tokens {
        // NOTE(review): "max_tokens" is the legacy chat-completions field;
        // kept as-is for compatibility with OpenAI-compatible servers.
        body["max_tokens"] = json!(max_tokens);
    }
    if request.tools.is_empty() {
        return body;
    }
    let tools: Vec<Value> = request.tools.iter().map(openai_tool_definition).collect();
    body["tools"] = Value::Array(tools);
    body["tool_choice"] = openai_tool_choice(&request.tool_choice);
    body
}
fn openai_stream_request(model_id: &str, request: &ModelRequest) -> Value {
let mut body = openai_chat_request(model_id, request);
body["stream"] = json!(true);
body["stream_options"] = json!({ "include_usage": true });
body
}
/// Accumulates the fragments of one streamed tool call: OpenAI delivers
/// the id and function name once and streams the JSON arguments
/// incrementally across chunks.
#[derive(Clone, Default)]
struct PartialToolCall {
    id: String,
    name: String,
    // Concatenation of the streamed "arguments" fragments (a JSON string).
    input: String,
}
impl PartialToolCall {
fn into_tool_call(self) -> Result<crate::ToolCall> {
if self.id.is_empty() || self.name.is_empty() {
return Err(Error::Parse("incomplete streamed tool call".to_string()));
}
Ok(crate::ToolCall {
id: self.id,
name: self.name,
input: self.input,
})
}
}
/// Mutable state threaded through SSE processing for one streamed response.
struct OpenAiStreamState {
    // Events produced since the last `take_events` drain.
    events: Vec<crate::StreamEvent>,
    // Partial tool calls keyed by the "index" field from the stream.
    tool_calls: BTreeMap<usize, PartialToolCall>,
    // Arrival order of text deltas and tool-call indices, replayed by
    // `finish_parts` to rebuild the final ordered parts.
    part_order: Vec<StreamPart>,
    // Tool-call indices already recorded in `part_order`.
    seen_tool_call_indices: BTreeSet<usize>,
    finish_reason: FinishReason,
    usage: Usage,
}
impl Default for OpenAiStreamState {
    // Manual impl so the finish reason starts as `Stop` (a stream that ends
    // without an explicit finish_reason is treated as a normal stop).
    fn default() -> Self {
        Self {
            events: Vec::new(),
            tool_calls: BTreeMap::new(),
            part_order: Vec::new(),
            seen_tool_call_indices: BTreeSet::new(),
            finish_reason: FinishReason::Stop,
            usage: Usage::default(),
        }
    }
}
impl OpenAiStreamState {
fn take_events(&mut self) -> Vec<crate::StreamEvent> {
std::mem::take(&mut self.events)
}
fn push_text(&mut self, text: String) {
self.part_order.push(StreamPart::Text(text.clone()));
self.events.push(crate::StreamEvent::TextDelta(text));
}
fn note_tool_call(&mut self, index: usize) {
if self.seen_tool_call_indices.insert(index) {
self.part_order.push(StreamPart::ToolCall(index));
}
}
fn finish_parts(&self) -> Result<Vec<Part>> {
let mut parts = Vec::with_capacity(self.part_order.len());
for part in &self.part_order {
match part {
StreamPart::Text(text) => match parts.last_mut() {
Some(Part::Text(existing)) => existing.push_str(text),
_ => parts.push(Part::Text(text.clone())),
},
StreamPart::ToolCall(index) => {
let call = self
.tool_calls
.get(index)
.ok_or_else(|| Error::Parse("missing streamed tool call".to_string()))?
.clone()
.into_tool_call()?;
parts.push(Part::ToolCall(call));
}
}
}
Ok(parts)
}
}
/// One entry in the stream's arrival order: either a raw text delta or a
/// reference (by stream index) to a tool call accumulated in
/// `OpenAiStreamState::tool_calls`.
#[derive(Clone, Debug)]
enum StreamPart {
    Text(String),
    ToolCall(usize),
}
/// Adapts a streaming HTTP response into the crate's event stream.
///
/// Incoming bytes are buffered and split into SSE frames; each frame's
/// payloads are folded into an [`OpenAiStreamState`]. Text deltas are
/// yielded as they arrive; tool calls and the terminal `Finish` event are
/// yielded once the stream completes. Failures are reported as a terminal
/// `StreamEvent::Error` instead of failing the stream itself.
async fn openai_stream_response_to_events(
    response: reqwest::Response,
) -> Result<crate::StreamTextStream> {
    let stream = stream! {
        let mut state = OpenAiStreamState::default();
        let mut response_stream = response.bytes_stream();
        // Bytes received so far that do not yet form a complete frame.
        let mut buffer = String::new();
        let mut done = false;
        while let Some(chunk) = response_stream.next().await {
            let chunk = match chunk {
                Ok(chunk) => chunk,
                Err(error) => {
                    yield crate::StreamEvent::Error(Error::Http(error.to_string()).to_string());
                    return;
                }
            };
            // NOTE(review): decoding each chunk independently assumes chunk
            // boundaries never split a multi-byte UTF-8 code point; a split
            // code point would abort the stream here — confirm against
            // observed server framing.
            let text = match std::str::from_utf8(&chunk) {
                Ok(text) => text,
                Err(error) => {
                    yield crate::StreamEvent::Error(
                        Error::Parse(format!("invalid OpenAI stream chunk: {error}")).to_string(),
                    );
                    return;
                }
            };
            buffer.push_str(text);
            // Drain every complete frame currently buffered.
            while let Some(frame) = next_sse_frame(&mut buffer) {
                match openai_process_sse_frame(&frame, &mut state) {
                    Ok(is_done) => {
                        for event in state.take_events() {
                            yield event;
                        }
                        if is_done {
                            done = true;
                            break;
                        }
                    }
                    Err(error) => {
                        yield crate::StreamEvent::Error(error.to_string());
                        return;
                    }
                }
            }
            if done {
                break;
            }
        }
        // The connection may close without a trailing blank line; process
        // any leftover partial frame before finishing.
        if !done && !buffer.trim().is_empty() {
            match openai_process_sse_frame(buffer.trim(), &mut state) {
                Ok(_) => {
                    for event in state.take_events() {
                        yield event;
                    }
                }
                Err(error) => {
                    yield crate::StreamEvent::Error(error.to_string());
                    return;
                }
            }
        }
        let parts = match state.finish_parts() {
            Ok(parts) => parts,
            Err(error) => {
                yield crate::StreamEvent::Error(error.to_string());
                return;
            }
        };
        // Tool calls are only complete once the stream ends, so they are
        // emitted here rather than incrementally.
        for part in &parts {
            if let Part::ToolCall(call) = part {
                yield crate::StreamEvent::ToolCall(call.clone());
            }
        }
        yield crate::StreamEvent::Finish {
            reason: state.finish_reason,
            usage: state.usage,
            parts,
        };
    };
    Ok(Box::pin(stream))
}
/// Pops the next complete SSE frame off the front of `buffer`.
///
/// Frames may be terminated by either a blank LF line (`\n\n`) or a blank
/// CRLF line (`\r\n\r\n`); whichever delimiter appears first wins. Returns
/// `None` when no complete frame is buffered yet, leaving `buffer` intact.
fn next_sse_frame(buffer: &mut String) -> Option<String> {
    // Candidate (start index, delimiter width) for each delimiter style.
    let candidates = [
        buffer.find("\n\n").map(|index| (index, "\n\n".len())),
        buffer.find("\r\n\r\n").map(|index| (index, "\r\n\r\n".len())),
    ];
    // The two delimiters can never start at the same index, so taking the
    // minimum (index, width) pair selects the earliest complete frame.
    let (end, width) = candidates.into_iter().flatten().min()?;
    let frame = buffer[..end].to_string();
    buffer.drain(..end + width);
    Some(frame)
}
/// Processes one SSE frame, feeding each `data:` payload into `state`.
///
/// Returns `Ok(true)` once OpenAI's `[DONE]` sentinel is seen, signalling
/// the caller to stop reading; `Ok(false)` otherwise.
///
/// # Errors
/// Propagates JSON/parse errors from the individual payloads.
fn openai_process_sse_frame(frame: &str, state: &mut OpenAiStreamState) -> Result<bool> {
    for line in frame.lines() {
        let line = line.trim();
        // Skip blank lines, SSE comments (": ...") and non-data fields
        // (e.g. "event:", "id:") — only data lines carry payloads.
        if line.is_empty() || line.starts_with(':') || !line.starts_with("data:") {
            continue;
        }
        let payload = line[5..].trim();
        if payload == "[DONE]" {
            return Ok(true);
        }
        // A bare "data:" line carries no payload; parsing "" as JSON would
        // fail, so skip it instead of erroring out the whole stream.
        if payload.is_empty() {
            continue;
        }
        openai_process_stream_payload(payload, state)?;
    }
    Ok(false)
}
/// Folds one parsed SSE `data:` payload (a chat-completion chunk) into the
/// stream state: usage totals, text deltas, tool-call fragments, and the
/// finish reason.
///
/// # Errors
/// [`Error::Json`] when the payload is not valid JSON.
fn openai_process_stream_payload(payload: &str, state: &mut OpenAiStreamState) -> Result<()> {
    let body =
        serde_json::from_str::<Value>(payload).map_err(|error| Error::Json(error.to_string()))?;
    let usage = body.get("usage");
    if usage.is_some() {
        // With stream_options.include_usage, usage arrives in a late chunk;
        // overwrite rather than accumulate.
        state.usage = openai_usage(&body);
    }
    if let Some(choices) = body.get("choices").and_then(Value::as_array) {
        for choice in choices {
            if let Some(delta) = choice.get("delta") {
                if let Some(content) = delta.get("content") {
                    let text = openai_text_content(content);
                    if !text.is_empty() {
                        state.push_text(text);
                    }
                }
                if let Some(tool_calls) = delta.get("tool_calls").and_then(Value::as_array) {
                    for tool_call in tool_calls {
                        // "index" groups fragments of the same call across
                        // chunks; a missing index defaults to 0.
                        let index = tool_call
                            .get("index")
                            .and_then(Value::as_u64)
                            .unwrap_or_default() as usize;
                        state.note_tool_call(index);
                        let entry = state.tool_calls.entry(index).or_default();
                        // id/name arrive once; arguments arrive as string
                        // fragments that are concatenated in order.
                        if let Some(id) = tool_call.get("id").and_then(Value::as_str) {
                            entry.id = id.to_string();
                        }
                        if let Some(function) = tool_call.get("function") {
                            if let Some(name) = function.get("name").and_then(Value::as_str) {
                                entry.name = name.to_string();
                            }
                            if let Some(arguments) =
                                function.get("arguments").and_then(Value::as_str)
                            {
                                entry.input.push_str(arguments);
                            }
                        }
                    }
                }
            }
            if let Some(reason) = choice.get("finish_reason").and_then(Value::as_str) {
                state.finish_reason = openai_finish_reason(Some(reason));
            }
        }
    }
    Ok(())
}
/// Converts the crate's message history into OpenAI chat-message JSON.
///
/// System and user messages carry their flattened text. Assistant messages
/// additionally serialize any tool-call parts; when an assistant turn has
/// tool calls but no text, "content" is set to JSON null. Tool messages use
/// the first `ToolResult` part's call id and output.
///
/// NOTE(review): a tool message carrying several `ToolResult` parts only
/// serializes the first one — confirm upstream always emits one result per
/// message.
fn openai_messages(messages: &[ModelMessage]) -> Vec<Value> {
    messages
        .iter()
        .map(|message| match message.role {
            Role::System => json!({ "role": "system", "content": message.text() }),
            Role::User => json!({ "role": "user", "content": message.text() }),
            Role::Assistant => {
                let content = message.text();
                let tool_calls = message
                    .parts
                    .iter()
                    .filter_map(|part| match part {
                        Part::ToolCall(call) => Some(json!({
                            "id": call.id,
                            "type": "function",
                            "function": {
                                "name": call.name,
                                "arguments": call.input,
                            }
                        })),
                        _ => None,
                    })
                    .collect::<Vec<_>>();
                if tool_calls.is_empty() {
                    json!({ "role": "assistant", "content": content })
                } else {
                    json!({
                        "role": "assistant",
                        "content": if content.is_empty() { Value::Null } else { Value::String(content) },
                        "tool_calls": tool_calls,
                    })
                }
            }
            Role::Tool => {
                let result = message.parts.iter().find_map(|part| match part {
                    Part::ToolResult(result) => Some(result),
                    _ => None,
                });
                match result {
                    Some(result) => json!({
                        "role": "tool",
                        "tool_call_id": result.call_id,
                        "content": result.output,
                    }),
                    // Fallback when no ToolResult part exists: send the
                    // flattened text without a call id.
                    None => json!({ "role": "tool", "content": message.text() }),
                }
            }
        })
        .collect()
}
/// Serializes one tool definition into OpenAI's "function" tool format.
fn openai_tool_definition(tool: &ToolDefinition) -> Value {
    let function = json!({
        "name": tool.name,
        "description": tool.description,
        "parameters": tool_schema_json(&tool.input_schema),
    });
    json!({ "type": "function", "function": function })
}
fn openai_tool_choice(tool_choice: &ToolChoice) -> Value {
match tool_choice {
ToolChoice::Auto => json!("auto"),
ToolChoice::None => json!("none"),
ToolChoice::Required(name) => json!({
"type": "function",
"function": { "name": name }
}),
}
}
/// Converts a [`ToolSchema`] into the JSON-Schema fragment OpenAI expects
/// for tool parameters. Object schemas always set
/// `additionalProperties: false` and list their required fields.
pub(crate) fn tool_schema_json(schema: &ToolSchema) -> Value {
    match schema {
        ToolSchema::String { description } => {
            json_with_description(json!({ "type": "string" }), description)
        }
        ToolSchema::Integer { description } => {
            json_with_description(json!({ "type": "integer" }), description)
        }
        ToolSchema::Number { description } => {
            json_with_description(json!({ "type": "number" }), description)
        }
        ToolSchema::Boolean { description } => {
            json_with_description(json!({ "type": "boolean" }), description)
        }
        ToolSchema::Array { description, items } => {
            // Element schemas are serialized recursively.
            let array = json!({ "type": "array", "items": tool_schema_json(items) });
            json_with_description(array, description)
        }
        ToolSchema::Object(object) => {
            // Single pass over the fields builds both the property map and
            // the required-name list.
            let mut properties = serde_json::Map::new();
            let mut required = Vec::new();
            for field in &object.fields {
                let mut field_schema = tool_schema_json(&field.schema);
                if let Some(description) = &field.description {
                    field_schema["description"] = json!(description);
                }
                if field.required {
                    required.push(Value::String(field.name.clone()));
                }
                properties.insert(field.name.clone(), field_schema);
            }
            json_with_description(
                json!({
                    "type": "object",
                    "properties": properties,
                    "required": required,
                    "additionalProperties": false,
                }),
                &object.description,
            )
        }
    }
}
fn json_with_description(mut value: Value, description: &Option<String>) -> Value {
if let Some(description) = description {
value["description"] = json!(description);
}
value
}
/// Converts a non-streaming chat-completion body into a [`ModelResponse`].
///
/// Only the first choice is consulted. Text content (string or part array)
/// becomes a leading `Part::Text`; each tool call is required to carry an
/// id, a function name, and string arguments.
///
/// # Errors
/// [`Error::Parse`] when any of the required fields are absent.
fn openai_chat_response_to_model_response(
    provider_id: &str,
    model_id: &str,
    body: &Value,
) -> Result<ModelResponse> {
    let choice = body
        .get("choices")
        .and_then(Value::as_array)
        .and_then(|choices| choices.first())
        .ok_or_else(|| Error::Parse("missing choice".to_string()))?;
    let message = choice
        .get("message")
        .ok_or_else(|| Error::Parse("missing message".to_string()))?;
    let mut parts = Vec::new();
    if let Some(content) = message.get("content") {
        let text = openai_text_content(content);
        // A tool-call-only reply typically has null content; skip the
        // empty text part in that case.
        if !text.is_empty() {
            parts.push(Part::Text(text));
        }
    }
    if let Some(tool_calls) = message.get("tool_calls").and_then(Value::as_array) {
        for tool_call in tool_calls {
            let id = tool_call
                .get("id")
                .and_then(Value::as_str)
                .ok_or_else(|| Error::Parse("missing tool call id".to_string()))?;
            let function = tool_call
                .get("function")
                .ok_or_else(|| Error::Parse("missing tool call function".to_string()))?;
            let name = function
                .get("name")
                .and_then(Value::as_str)
                .ok_or_else(|| Error::Parse("missing tool call name".to_string()))?;
            // "arguments" is a JSON-encoded string, kept verbatim.
            let input = function
                .get("arguments")
                .and_then(Value::as_str)
                .ok_or_else(|| Error::Parse("missing tool call arguments".to_string()))?;
            parts.push(Part::ToolCall(crate::ToolCall {
                id: id.to_string(),
                name: name.to_string(),
                input: input.to_string(),
            }));
        }
    }
    Ok(ModelResponse {
        parts,
        finish_reason: openai_finish_reason(choice.get("finish_reason").and_then(Value::as_str)),
        usage: openai_usage(body),
        response_metadata: crate::metadata_with_provider(provider_id, model_id),
    })
}
/// Extracts plain text from an OpenAI "content" value, which may be a bare
/// string or an array of content parts carrying "text" fields. Any other
/// shape (including null) yields an empty string.
fn openai_text_content(content: &Value) -> String {
    match content {
        Value::String(text) => text.to_owned(),
        Value::Array(parts) => {
            let mut text = String::new();
            for part in parts {
                if let Some(fragment) = part.get("text").and_then(Value::as_str) {
                    text.push_str(fragment);
                }
            }
            text
        }
        _ => String::new(),
    }
}
/// Maps OpenAI's `finish_reason` string onto the crate's [`FinishReason`].
/// A missing reason is treated as a normal stop; any unrecognized value
/// maps to `Error`.
fn openai_finish_reason(reason: Option<&str>) -> FinishReason {
    let Some(reason) = reason else {
        return FinishReason::Stop;
    };
    match reason {
        "stop" => FinishReason::Stop,
        "tool_calls" => FinishReason::ToolCalls,
        "length" => FinishReason::Length,
        _ => FinishReason::Error,
    }
}
/// Reads token counts from a response body's optional "usage" object;
/// missing or non-numeric fields default to zero.
pub(crate) fn openai_usage(body: &Value) -> Usage {
    // Shared lookup for one counter inside the optional usage object.
    fn count(usage: Option<&Value>, key: &str) -> usize {
        usage
            .and_then(|usage| usage.get(key))
            .and_then(Value::as_u64)
            .unwrap_or_default() as usize
    }
    let usage = body.get("usage");
    Usage {
        input_tokens: count(usage, "prompt_tokens"),
        output_tokens: count(usage, "completion_tokens"),
    }
}
/// Pulls the human-readable message out of an OpenAI error body
/// (`{"error":{"message":…}}`), falling back to a generic message when the
/// shape is unexpected.
fn openai_error_message(body: &Value) -> String {
    let message = body
        .get("error")
        .and_then(|error| error.get("message"))
        .and_then(Value::as_str);
    message.unwrap_or("unknown OpenAI error").to_string()
}
#[cfg(test)]
mod tests {
    use super::*;
    // SSE frames may be CRLF-delimited; both frames must split cleanly.
    #[test]
    fn splits_crlf_delimited_sse_frames() {
        let mut buffer = concat!(
            "data: {\"choices\":[{\"delta\":{\"content\":\"hello\"}}]}\r\n\r\n",
            "data: [DONE]\r\n\r\n"
        )
        .to_string();
        let first = next_sse_frame(&mut buffer).unwrap();
        let second = next_sse_frame(&mut buffer).unwrap();
        assert_eq!(
            first,
            "data: {\"choices\":[{\"delta\":{\"content\":\"hello\"}}]}"
        );
        assert_eq!(second, "data: [DONE]");
        assert!(next_sse_frame(&mut buffer).is_none());
    }
    // Tool-call arguments arrive split across chunks and must be
    // concatenated under the same index; the final usage chunk wins.
    #[test]
    fn parses_streamed_text_and_tool_calls() {
        let mut state = OpenAiStreamState::default();
        openai_process_sse_frame(
            r#"data: {"choices":[{"delta":{"content":"hello "},"finish_reason":null}]}
data: {"choices":[{"delta":{"tool_calls":[{"index":0,"id":"call_1","function":{"name":"save","arguments":"{\"note\":"}}]},"finish_reason":null}]}
data: {"choices":[{"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\"rust\"}"}}]},"finish_reason":"tool_calls"}],"usage":{"prompt_tokens":3,"completion_tokens":5}}"#,
            &mut state,
        )
        .unwrap();
        assert_eq!(state.finish_reason, FinishReason::ToolCalls);
        assert_eq!(state.usage.input_tokens, 3);
        assert_eq!(state.usage.output_tokens, 5);
        assert!(matches!(
            state.events.first(),
            Some(crate::StreamEvent::TextDelta(text)) if text == "hello "
        ));
        let parts = state.finish_parts().unwrap();
        let call = state
            .tool_calls
            .get(&0)
            .unwrap()
            .clone()
            .into_tool_call()
            .unwrap();
        assert_eq!(call.id, "call_1");
        assert_eq!(call.name, "save");
        assert_eq!(call.input, r#"{"note":"rust"}"#);
        assert!(matches!(parts.first(), Some(Part::Text(text)) if text == "hello "));
        assert!(matches!(parts.get(1), Some(Part::ToolCall(call)) if call.name == "save"));
    }
    // Consecutive text deltas collapse into one Part::Text at finish time.
    #[test]
    fn merges_adjacent_text_parts_in_finish() {
        let mut state = OpenAiStreamState::default();
        openai_process_sse_frame(
            r#"data: {"choices":[{"delta":{"content":"Once"},"finish_reason":null}]}
data: {"choices":[{"delta":{"content":" upon"},"finish_reason":null}]}
data: {"choices":[{"delta":{"content":" a time"},"finish_reason":"stop"}],"usage":{"prompt_tokens":1,"completion_tokens":3}}"#,
            &mut state,
        )
        .unwrap();
        let parts = state.finish_parts().unwrap();
        assert_eq!(parts, vec![Part::Text("Once upon a time".to_string())]);
    }
    // An explicitly configured key must short-circuit the env lookup.
    #[test]
    fn prefers_manual_api_key_over_env_lookup() {
        let provider = OpenAiProvider::new("lmstudio", "LM Studio").api_key("dummy");
        assert_eq!(openai_api_key(&provider).unwrap(), "dummy");
    }
}