use std::time::Duration;
use crate::OpenAI;
use crate::error::OpenAIError;
use crate::runtime::sleep;
use crate::types::responses::{Response, ResponseCreateRequest};
/// Sends `request` and, unless it finishes first, races an identical "hedge"
/// request against it after `hedge_delay`; the first successful response wins.
///
/// * `hedge_delay = None` starts both requests immediately.
/// * If the primary succeeds before the delay elapses, no hedge is sent.
/// * If the primary fails before the delay elapses, the fallback request is
///   sent at once instead of waiting out the remainder of the delay.
///
/// Note that this may send the same request twice.
///
/// # Errors
///
/// Returns an error only when both attempts fail, in which case the error of
/// the attempt that finished last is returned.
pub async fn hedged_request(
    client: &OpenAI,
    request: ResponseCreateRequest,
    hedge_delay: Option<Duration>,
) -> Result<Response, OpenAIError> {
    let responses = client.responses();
    let primary = responses.create(request.clone());
    tokio::pin!(primary);

    // Phase 1: give the primary a head start of `hedge_delay`.
    if let Some(delay) = hedge_delay {
        tokio::select! {
            result = &mut primary => {
                // Finished inside the head start, so there is nothing to race.
                // On failure fall back right away rather than sleeping out the
                // rest of the delay.
                return match result {
                    Ok(resp) => Ok(resp),
                    Err(_) => responses.create(request).await,
                };
            }
            _ = sleep(delay) => {}
        }
    }

    // Phase 2: the primary is still in flight, so race it against the hedge.
    let hedge = responses.create(request);
    tokio::pin!(hedge);
    tokio::select! {
        result = &mut primary => match result {
            Ok(resp) => Ok(resp),
            Err(_) => hedge.await,
        },
        result = &mut hedge => match result {
            Ok(resp) => Ok(resp),
            Err(_) => primary.await,
        },
    }
}
/// Like [`hedged_request`], but with up to three staggered attempts.
///
/// `n` is clamped to `1..=3`:
///
/// * `1`: a single, un-hedged request.
/// * `2`: delegates to [`hedged_request`] with the same `hedge_delay`.
/// * `3`: attempt *k* starts roughly `(k - 1) * hedge_delay` after the first.
///   All attempts start at once when the delay is `None` or zero. If the first
///   attempt fails before its head start elapses, the other two are handed to
///   [`hedged_request`] immediately instead of waiting.
///
/// The first successful response wins; attempts still in flight are dropped.
///
/// # Errors
///
/// Returns an error only when every attempt fails.
pub async fn hedged_request_n(
    client: &OpenAI,
    request: ResponseCreateRequest,
    n: usize,
    hedge_delay: Option<Duration>,
) -> Result<Response, OpenAIError> {
    match n.clamp(1, 3) {
        1 => return client.responses().create(request).await,
        2 => return hedged_request(client, request, hedge_delay).await,
        _ => {}
    }

    // A zero delay means "no stagger", exactly like `None`.
    let stagger = hedge_delay.filter(|delay| !delay.is_zero());

    let responses = client.responses();
    let primary = responses.create(request.clone());
    tokio::pin!(primary);

    // Give the first attempt a head start. Staggering sequentially, rather
    // than sleeping `delay * 2` for the third attempt, avoids `Duration`
    // multiplication, which panics on overflow.
    if let Some(delay) = stagger {
        tokio::select! {
            result = &mut primary => {
                return match result {
                    Ok(resp) => Ok(resp),
                    Err(_) => hedged_request(client, request, stagger).await,
                };
            }
            _ = sleep(delay) => {}
        }
    }

    // Attempts 2 and 3 form an ordinary two-way hedge that starts now. The
    // third attempt therefore goes out one more `stagger` later.
    let rest = hedged_request(client, request, stagger);
    tokio::pin!(rest);
    tokio::select! {
        result = &mut primary => match result {
            Ok(resp) => Ok(resp),
            Err(_) => rest.await,
        },
        result = &mut rest => match result {
            Ok(resp) => Ok(resp),
            Err(_) => primary.await,
        },
    }
}
pub async fn speculative<V>(
client: &OpenAI,
step1: ResponseCreateRequest,
step2: ResponseCreateRequest,
validate_step1: V,
) -> Result<(Response, Response), OpenAIError>
where
V: FnOnce(&Response) -> bool,
{
let responses = client.responses();
let (result1, result2) = tokio::join!(responses.create(step1), responses.create(step2),);
let resp1 = result1?;
if !validate_step1(&resp1) {
return Err(OpenAIError::InvalidArgument(
"speculative: step1 validation failed, step2 result discarded".into(),
));
}
let resp2 = result2?;
Ok((resp1, resp2))
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::ClientConfig;

    /// Builds a minimal `/responses` success body with a single assistant
    /// message whose only output text is `text`.
    fn response_json(id: &str, text: &str) -> String {
        format!(
            r#"{{
"id": "{id}",
"object": "response",
"created_at": 1677610602.0,
"model": "gpt-4o",
"output": [{{
"type": "message",
"id": "msg-1",
"role": "assistant",
"status": "completed",
"content": [{{
"type": "output_text",
"text": "{text}",
"annotations": []
}}]
}}],
"status": "completed",
"usage": {{
"input_tokens": 10,
"output_tokens": 5,
"total_tokens": 15
}}
}}"#
        )
    }

    #[tokio::test]
    async fn test_hedged_request_returns_first_success() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("POST", "/responses")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(response_json("resp-hedge", "hedged!"))
            .expect_at_least(1)
            .create_async()
            .await;
        let client = OpenAI::with_config(
            ClientConfig::new("sk-test")
                .base_url(server.url())
                .max_retries(0),
        );
        let request = ResponseCreateRequest::new("gpt-4o").input("Hello");
        let resp = hedged_request(&client, request, None).await.unwrap();
        assert_eq!(resp.output_text(), "hedged!");
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_hedged_request_with_delay() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("POST", "/responses")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(response_json("resp-delayed", "delayed hedge"))
            .expect_at_least(1)
            .create_async()
            .await;
        let client = OpenAI::with_config(
            ClientConfig::new("sk-test")
                .base_url(server.url())
                .max_retries(0),
        );
        let request = ResponseCreateRequest::new("gpt-4o").input("Hello");
        let resp = hedged_request(&client, request, Some(Duration::from_millis(50)))
            .await
            .unwrap();
        assert_eq!(resp.output_text(), "delayed hedge");
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_hedged_request_fallback_on_first_failure() {
        let mut server = mockito::Server::new_async().await;
        let _mock_fail = server
            .mock("POST", "/responses")
            .with_status(500)
            .with_body(
                r#"{"error":{"message":"fail","type":"server_error","param":null,"code":null}}"#,
            )
            .create_async()
            .await;
        let _mock_ok = server
            .mock("POST", "/responses")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(response_json("resp-fallback", "recovered"))
            .create_async()
            .await;
        let client = OpenAI::with_config(
            ClientConfig::new("sk-test")
                .base_url(server.url())
                .max_retries(0),
        );
        let request = ResponseCreateRequest::new("gpt-4o").input("Hello");
        let resp = hedged_request(&client, request, None).await.unwrap();
        assert_eq!(resp.output_text(), "recovered");
    }

    #[tokio::test]
    async fn test_hedged_request_n_single() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("POST", "/responses")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(response_json("resp-n1", "single"))
            .expect(1)
            .create_async()
            .await;
        let client = OpenAI::with_config(
            ClientConfig::new("sk-test")
                .base_url(server.url())
                .max_retries(0),
        );
        let request = ResponseCreateRequest::new("gpt-4o").input("Hello");
        let resp = hedged_request_n(&client, request, 1, None).await.unwrap();
        assert_eq!(resp.output_text(), "single");
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_hedged_request_n_triple() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("POST", "/responses")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(response_json("resp-n3", "triple"))
            .expect_at_least(1)
            .create_async()
            .await;
        let client = OpenAI::with_config(
            ClientConfig::new("sk-test")
                .base_url(server.url())
                .max_retries(0),
        );
        let request = ResponseCreateRequest::new("gpt-4o").input("Hello");
        let resp = hedged_request_n(&client, request, 3, None).await.unwrap();
        assert_eq!(resp.output_text(), "triple");
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_hedged_request_n_capped_at_3() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("POST", "/responses")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(response_json("resp-cap", "capped"))
            .expect_at_most(3)
            .create_async()
            .await;
        let client = OpenAI::with_config(
            ClientConfig::new("sk-test")
                .base_url(server.url())
                .max_retries(0),
        );
        let request = ResponseCreateRequest::new("gpt-4o").input("Hello");
        let resp = hedged_request_n(&client, request, 10, None).await.unwrap();
        assert_eq!(resp.output_text(), "capped");
        // Without this assert the `expect_at_most(3)` cap was never checked.
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_speculative_both_succeed_validation_passes() {
        let mut server = mockito::Server::new_async().await;
        let mock = server
            .mock("POST", "/responses")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(response_json("resp-spec", "safe content"))
            .expect_at_least(2)
            .create_async()
            .await;
        let client = OpenAI::with_config(
            ClientConfig::new("sk-test")
                .base_url(server.url())
                .max_retries(0),
        );
        let step1 = ResponseCreateRequest::new("gpt-4o").input("moderate this");
        let step2 = ResponseCreateRequest::new("gpt-4o").input("generate answer");
        let (resp1, resp2) =
            speculative(&client, step1, step2, |r| r.output_text().contains("safe"))
                .await
                .unwrap();
        assert_eq!(resp1.id, "resp-spec");
        assert_eq!(resp2.id, "resp-spec");
        // Both responses were returned, so both requests must have hit the server.
        mock.assert_async().await;
    }

    #[tokio::test]
    async fn test_speculative_validation_fails() {
        let mut server = mockito::Server::new_async().await;
        let _mock = server
            .mock("POST", "/responses")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(response_json("resp-spec-fail", "unsafe content"))
            .expect_at_least(2)
            .create_async()
            .await;
        let client = OpenAI::with_config(
            ClientConfig::new("sk-test")
                .base_url(server.url())
                .max_retries(0),
        );
        let step1 = ResponseCreateRequest::new("gpt-4o").input("moderate this");
        let step2 = ResponseCreateRequest::new("gpt-4o").input("generate answer");
        let err = speculative(&client, step1, step2, |r| {
            r.output_text().contains("definitely_not_here")
        })
        .await
        .unwrap_err();
        assert!(
            matches!(err, OpenAIError::InvalidArgument(_)),
            "expected InvalidArgument, got: {err:?}"
        );
    }

    #[tokio::test]
    async fn test_speculative_step1_api_error() {
        let mut server = mockito::Server::new_async().await;
        let _mock_fail = server
            .mock("POST", "/responses")
            .with_status(500)
            .with_body(
                r#"{"error":{"message":"boom","type":"server_error","param":null,"code":null}}"#,
            )
            .create_async()
            .await;
        let _mock_ok = server
            .mock("POST", "/responses")
            .with_status(200)
            .with_header("content-type", "application/json")
            .with_body(response_json("resp-ok", "ok"))
            .create_async()
            .await;
        let client = OpenAI::with_config(
            ClientConfig::new("sk-test")
                .base_url(server.url())
                .max_retries(0),
        );
        let step1 = ResponseCreateRequest::new("gpt-4o").input("moderate");
        let step2 = ResponseCreateRequest::new("gpt-4o").input("generate");
        let err = speculative(&client, step1, step2, |_| true)
            .await
            .unwrap_err();
        assert!(
            matches!(err, OpenAIError::ApiError { .. }),
            "expected ApiError, got: {err:?}"
        );
    }
}