use std::collections::BTreeMap;
use std::pin::Pin;
use async_stream::try_stream;
use futures_util::{Stream, StreamExt};
use reqwest::Method;
use serde::{Deserialize, Serialize};
use serde_json::Value;
use crate::config::ServiceBase;
use crate::{
ChatMessage, ChatStream, Client, CreateChatCompletionArgs, CreateChatCompletionResponse,
ModelId, RekaError, Result,
};
/// Client for research-enabled chat completions.
///
/// Wraps the crate-level [`Client`]; both [`create`](Self::create) and
/// [`stream`](Self::stream) issue a `POST` to the chat service's
/// `/chat/completions` endpoint with a body built from [`CreateResearchArgs`].
#[derive(Clone)]
pub struct ResearchClient {
    /// Underlying client used to build and send every request.
    client: Client,
}

impl ResearchClient {
    /// Path shared by the one-shot and streaming research requests.
    const ENDPOINT: &'static str = "/chat/completions";

    /// Wraps an existing [`Client`].
    pub(crate) fn new(client: Client) -> Self {
        Self { client }
    }

    /// Runs a research completion and returns the complete response.
    ///
    /// # Errors
    ///
    /// Returns [`RekaError::InvalidRequest`] when `args` fails client-side
    /// validation; otherwise propagates whatever error `send_json` produces.
    pub async fn create(&self, args: &CreateResearchArgs) -> Result<CreateChatCompletionResponse> {
        args.validate(false)?;
        let body = ResearchBody::standard(args);
        self.client
            .request(ServiceBase::Chat, Method::POST, Self::ENDPOINT)
            .json(&body)
            .send_json()
            .await
    }

    /// Runs a research completion over server-sent events.
    ///
    /// Each event's `data` is decoded as a [`crate::ChatStreamEvent`]. The
    /// stream finishes at the `[DONE]` sentinel or when the event source is
    /// exhausted.
    ///
    /// # Errors
    ///
    /// Returns [`RekaError::InvalidRequest`] when `args` fails validation (for
    /// example, parallel thinking with a mode other than `"none"`), or any error
    /// raised while opening the event stream. Errors on individual events
    /// (transport or JSON decoding) are yielded as an `Err` item, after which
    /// the stream ends.
    pub async fn stream(&self, args: &CreateResearchArgs) -> Result<ChatStream> {
        args.validate(true)?;
        let body = ResearchBody::streaming(args);
        let sse = self
            .client
            .request(ServiceBase::Chat, Method::POST, Self::ENDPOINT)
            .accept("text/event-stream")
            .json(&body)
            .send_sse()
            .await?;

        let chunks = try_stream! {
            let mut sse = sse;
            while let Some(next) = sse.next().await {
                let event = next?;
                // `[DONE]` is the end-of-stream sentinel, not a JSON payload.
                if event.data == "[DONE]" {
                    break;
                }
                let chunk = serde_json::from_str::<crate::ChatStreamEvent>(&event.data)
                    .map_err(|source| RekaError::decode(Self::ENDPOINT, event.data, source))?;
                yield chunk;
            }
        };

        let inner = Box::pin(chunks)
            as Pin<Box<dyn Stream<Item = Result<crate::ChatStreamEvent>> + Send>>;
        Ok(ChatStream { inner })
    }
}
/// Arguments for a research completion.
///
/// Combines standard chat-completion arguments with the research-specific
/// `response_format` and `research` options. Build it with [`new`](Self::new)
/// and the `with_*` methods.
#[derive(Debug, Clone, PartialEq)]
pub struct CreateResearchArgs {
    /// Model, messages and sampling settings shared with plain chat completions.
    pub chat: CreateChatCompletionArgs,
    /// Optional structured-output format, sent as the top-level `response_format`.
    pub response_format: Option<Value>,
    /// Web-search / parallel-thinking configuration, sent as `research`.
    pub research: ResearchOptions,
}

impl CreateResearchArgs {
    /// Creates arguments for `model` and `messages` with no research options
    /// and no response format.
    pub fn new(model: ModelId, messages: Vec<ChatMessage>) -> Self {
        Self {
            chat: CreateChatCompletionArgs::new(model, messages),
            response_format: None,
            research: ResearchOptions::default(),
        }
    }

    /// Replaces the model.
    #[must_use]
    pub fn with_model(mut self, model: ModelId) -> Self {
        self.chat.model = model;
        self
    }

    /// Sets the sampling temperature.
    #[must_use]
    pub fn with_temperature(mut self, temperature: f32) -> Self {
        self.chat.temperature = Some(temperature);
        self
    }

    /// Sets the maximum number of tokens to generate.
    #[must_use]
    pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
        self.chat.max_tokens = Some(max_tokens);
        self
    }

    /// Sets the stop sequences, replacing any previously configured ones.
    #[must_use]
    pub fn with_stop<I, S>(mut self, stop: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.chat.stop = Some(stop.into_iter().map(Into::into).collect());
        self
    }

    /// Sets the `response_format` payload (e.g. a `json_schema` object).
    ///
    /// The value is forwarded as-is, without inspection.
    #[must_use]
    pub fn with_response_format(mut self, response_format: Value) -> Self {
        self.response_format = Some(response_format);
        self
    }

    /// Sets the web-search options under `research.web_search`.
    #[must_use]
    pub fn with_web_search(mut self, web_search: WebSearchOptions) -> Self {
        self.research.web_search = Some(web_search);
        self
    }

    /// Sets the parallel-thinking options under `research.parallel_thinking`.
    ///
    /// Any mode other than `"none"` is rejected when the request is streamed.
    #[must_use]
    pub fn with_parallel_thinking(mut self, parallel_thinking: ParallelThinkingOptions) -> Self {
        self.research.parallel_thinking = Some(parallel_thinking);
        self
    }

    /// Inserts `value` under `key` in the chat arguments' `extra` map.
    ///
    /// NOTE(review): how `extra` is serialized is defined by
    /// `CreateChatCompletionArgs`, presumably as additional top-level fields.
    #[must_use]
    pub fn insert_extra(mut self, key: impl Into<String>, value: Value) -> Self {
        self.chat.extra.insert(key.into(), value);
        self
    }

    /// Checks client-side constraints before a request is sent.
    ///
    /// `streaming` must be `true` when validating for a streaming request.
    ///
    /// # Errors
    ///
    /// Returns [`RekaError::InvalidRequest`] when:
    /// - both `allowed_domains` and `blocked_domains` are set on the web-search
    ///   options, or
    /// - `streaming` is `true` and parallel thinking is configured with a mode
    ///   other than `"none"`.
    fn validate(&self, streaming: bool) -> Result<()> {
        if let Some(web_search) = &self.research.web_search
            && web_search.allowed_domains.is_some()
            && web_search.blocked_domains.is_some()
        {
            return Err(RekaError::InvalidRequest(
                "research.web_search.allowed_domains and blocked_domains cannot both be set"
                    .to_string(),
            ));
        }
        // Parallel thinking cannot be combined with streaming, but an explicit
        // "none" mode is equivalent to turning it off and is therefore allowed.
        if streaming
            && let Some(parallel_thinking) = &self.research.parallel_thinking
            && parallel_thinking.mode != "none"
        {
            return Err(RekaError::InvalidRequest(
                "research.parallel_thinking is not supported when streaming is enabled".to_string(),
            ));
        }
        Ok(())
    }
}
#[derive(Serialize)]
struct ResearchBody<'a> {
#[serde(flatten)]
chat: &'a CreateChatCompletionArgs,
#[serde(skip_serializing_if = "Option::is_none")]
response_format: Option<&'a Value>,
#[serde(skip_serializing_if = "ResearchOptions::is_empty")]
research: &'a ResearchOptions,
#[serde(skip_serializing_if = "is_false")]
stream: bool,
}
impl<'a> ResearchBody<'a> {
fn standard(args: &'a CreateResearchArgs) -> Self {
Self {
chat: &args.chat,
response_format: args.response_format.as_ref(),
research: &args.research,
stream: false,
}
}
fn streaming(args: &'a CreateResearchArgs) -> Self {
Self {
stream: true,
..Self::standard(args)
}
}
}
/// The `research` object of a research request.
///
/// Keys without a typed field are kept in `extra` and flattened next to the
/// typed fields. This lets options the SDK does not model yet still be sent.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
pub struct ResearchOptions {
    /// Web-search configuration; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub web_search: Option<WebSearchOptions>,
    /// Parallel-thinking configuration; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parallel_thinking: Option<ParallelThinkingOptions>,
    /// Additional keys serialized at the same level as the fields above.
    #[serde(default, flatten)]
    pub extra: BTreeMap<String, Value>,
}

impl ResearchOptions {
    /// Returns `true` when no option is set, so the request body can omit the
    /// `research` key entirely.
    fn is_empty(&self) -> bool {
        let has_typed_option = self.web_search.is_some() || self.parallel_thinking.is_some();
        !has_typed_option && self.extra.is_empty()
    }
}
/// Web-search settings, sent as `research.web_search`.
///
/// `allowed_domains` and `blocked_domains` are mutually exclusive:
/// [`CreateResearchArgs`] validation rejects requests that set both.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
pub struct WebSearchOptions {
    /// Upper bound on web-search uses; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_uses: Option<u32>,
    /// Domains to restrict search to; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub allowed_domains: Option<Vec<String>>,
    /// Domains to exclude from search; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub blocked_domains: Option<Vec<String>>,
    /// Location hint for the search; omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub user_location: Option<UserLocation>,
    /// Additional keys serialized next to the fields above.
    #[serde(default, flatten)]
    pub extra: BTreeMap<String, Value>,
}

impl WebSearchOptions {
    /// Creates options with every field unset.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets `max_uses`.
    #[must_use]
    pub fn with_max_uses(mut self, max_uses: u32) -> Self {
        self.max_uses = Some(max_uses);
        self
    }

    /// Sets `allowed_domains`, replacing any previous list.
    #[must_use]
    pub fn with_allowed_domains<I, S>(mut self, domains: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.allowed_domains = Some(domains.into_iter().map(Into::into).collect());
        self
    }

    /// Sets `blocked_domains`, replacing any previous list.
    #[must_use]
    pub fn with_blocked_domains<I, S>(mut self, domains: I) -> Self
    where
        I: IntoIterator<Item = S>,
        S: Into<String>,
    {
        self.blocked_domains = Some(domains.into_iter().map(Into::into).collect());
        self
    }

    /// Sets `user_location`.
    #[must_use]
    pub fn with_user_location(mut self, user_location: UserLocation) -> Self {
        self.user_location = Some(user_location);
        self
    }
}
/// Location hint attached to web search via `WebSearchOptions::user_location`.
///
/// Every field is optional and omitted from the payload when unset.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
pub struct UserLocation {
    /// City name, e.g. `"San Francisco"`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub city: Option<String>,
    /// Region or state, e.g. `"CA"`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub region: Option<String>,
    /// Country, e.g. `"US"`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub country: Option<String>,
    /// Time zone name, e.g. `"America/Los_Angeles"`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub timezone: Option<String>,
    /// Additional keys serialized next to the fields above.
    #[serde(default, flatten)]
    pub extra: BTreeMap<String, Value>,
}

impl UserLocation {
    /// Creates a location with every field unset.
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets `city`.
    #[must_use]
    pub fn with_city(mut self, city: impl Into<String>) -> Self {
        self.city = Some(city.into());
        self
    }

    /// Sets `region`.
    #[must_use]
    pub fn with_region(mut self, region: impl Into<String>) -> Self {
        self.region = Some(region.into());
        self
    }

    /// Sets `country`.
    #[must_use]
    pub fn with_country(mut self, country: impl Into<String>) -> Self {
        self.country = Some(country.into());
        self
    }

    /// Sets `timezone`.
    #[must_use]
    pub fn with_timezone(mut self, timezone: impl Into<String>) -> Self {
        self.timezone = Some(timezone.into());
        self
    }
}
/// Parallel-thinking settings, sent as `research.parallel_thinking`.
///
/// `mode` is sent as a plain string. [`low`](Self::low), [`high`](Self::high)
/// and [`none`](Self::none) cover the values used in this module. Streaming
/// requests only accept the `"none"` mode.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct ParallelThinkingOptions {
    /// Mode string, e.g. `"low"`, `"high"` or `"none"`.
    pub mode: String,
    /// Additional keys serialized next to `mode`.
    #[serde(default, flatten)]
    pub extra: BTreeMap<String, Value>,
}

impl ParallelThinkingOptions {
    /// Creates options with the given `mode` and an empty `extra` map.
    pub fn new(mode: impl Into<String>) -> Self {
        let mode: String = mode.into();
        Self {
            extra: BTreeMap::default(),
            mode,
        }
    }

    /// Shorthand for `ParallelThinkingOptions::new("low")`.
    pub fn low() -> Self {
        Self::new("low")
    }

    /// Shorthand for `ParallelThinkingOptions::new("high")`.
    pub fn high() -> Self {
        Self::new("high")
    }

    /// Shorthand for `ParallelThinkingOptions::new("none")`.
    pub fn none() -> Self {
        Self::new("none")
    }
}
/// `skip_serializing_if` predicate that skips a `bool` field when it is `false`.
fn is_false(value: &bool) -> bool {
    !value
}
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::{
        CreateResearchArgs, ParallelThinkingOptions, ResearchBody, UserLocation, WebSearchOptions,
    };
    use crate::{ChatMessage, ModelId, RekaError};

    /// Minimal arguments shared by the validation and serialization tests.
    fn minimal_args() -> CreateResearchArgs {
        CreateResearchArgs::new(ModelId::flash_research(), vec![ChatMessage::user("test")])
    }

    #[test]
    fn research_request_injects_extra_body_fields() {
        let args = CreateResearchArgs::new(
            ModelId::flash_research(),
            vec![ChatMessage::user("Who won the UEFA Nations League 2025?")],
        )
        .with_web_search(
            WebSearchOptions::new()
                .with_max_uses(3)
                .with_allowed_domains(["uefa.com", "espn.com"])
                .with_user_location(
                    UserLocation::new()
                        .with_city("San Francisco")
                        .with_region("CA")
                        .with_country("US")
                        .with_timezone("America/Los_Angeles"),
                ),
        )
        .with_response_format(json!({
            "type": "json_schema",
            "json_schema": {
                "name": "answer_format",
                "schema": {
                    "type": "object",
                    "properties": {
                        "winner": { "type": "string" }
                    },
                    "required": ["winner"]
                }
            }
        }));
        let json =
            serde_json::to_value(ResearchBody::standard(&args)).expect("payload should serialize");
        assert_eq!(json["model"], "reka-flash-research");
        assert_eq!(json["research"]["web_search"]["max_uses"], 3);
        assert_eq!(
            json["research"]["web_search"]["allowed_domains"][0],
            "uefa.com"
        );
        assert_eq!(
            json["research"]["web_search"]["user_location"]["city"],
            "San Francisco"
        );
        assert_eq!(
            json["research"]["web_search"]["user_location"]["timezone"],
            "America/Los_Angeles"
        );
        assert_eq!(json["response_format"]["type"], "json_schema");
        assert_eq!(json.get("stream"), None);
    }

    #[test]
    fn streaming_body_sets_stream_flag() {
        let args = minimal_args();
        let json =
            serde_json::to_value(ResearchBody::streaming(&args)).expect("payload should serialize");
        assert_eq!(json["stream"], true);
    }

    #[test]
    fn omits_empty_research_options() {
        let args = minimal_args();
        let json =
            serde_json::to_value(ResearchBody::standard(&args)).expect("payload should serialize");
        assert_eq!(json.get("research"), None);
    }

    #[test]
    fn serializes_parallel_thinking_mode() {
        let args = minimal_args().with_parallel_thinking(ParallelThinkingOptions::high());
        let json =
            serde_json::to_value(ResearchBody::standard(&args)).expect("payload should serialize");
        assert_eq!(json["research"]["parallel_thinking"]["mode"], "high");
    }

    #[test]
    fn rejects_conflicting_domain_filters() {
        let error = minimal_args()
            .with_web_search(
                WebSearchOptions::new()
                    .with_allowed_domains(["uefa.com"])
                    .with_blocked_domains(["espn.com"]),
            )
            .validate(false)
            .expect_err("request should be rejected");
        assert!(matches!(error, RekaError::InvalidRequest(_)));
    }

    #[test]
    fn rejects_parallel_thinking_while_streaming() {
        let error = minimal_args()
            .with_parallel_thinking(ParallelThinkingOptions::high())
            .validate(true)
            .expect_err("request should be rejected");
        assert!(matches!(error, RekaError::InvalidRequest(_)));
    }

    #[test]
    fn allows_parallel_thinking_none_while_streaming() {
        let result = minimal_args()
            .with_parallel_thinking(ParallelThinkingOptions::none())
            .validate(true);
        assert!(result.is_ok(), "`none` mode must be accepted when streaming");
    }

    #[test]
    fn allows_parallel_thinking_without_streaming() {
        let result = minimal_args()
            .with_parallel_thinking(ParallelThinkingOptions::high())
            .validate(false);
        assert!(result.is_ok(), "parallel thinking must be accepted when not streaming");
    }
}