use opencode_sdk::error::OpencodeError;
use opencode_sdk::server::{ManagedServer, ServerOptions};
use opencode_sdk::types::event::Event;
use opencode_sdk::types::message::{Message, Part, PromptRequest};
use opencode_sdk::types::permission::{
PermissionAction, PermissionReply, PermissionReplyRequest, PermissionRule, Ruleset,
};
use opencode_sdk::types::session::CreateSessionRequest;
use opencode_sdk::Client;
use serde::{Deserialize, Deserializer, Serialize};
use std::path::PathBuf;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use thiserror::Error;
/// File name of the JSON payload written inside `ResearchRequest::output_directory`.
const RESPONSE_FILE_NAME: &str = "response.json";
/// Errors produced by the web-research workflow in this file.
#[derive(Debug, Error)]
pub enum WebResearchError {
/// Error reported by the opencode SDK; converted automatically via `From`.
#[error("opencode sdk error: {0}")]
Opencode(#[from] OpencodeError),
/// Filesystem or environment I/O failure; converted automatically via `From`.
#[error("io error: {0}")]
Io(#[from] std::io::Error),
/// JSON (de)serialization failure; converted automatically via `From`.
#[error("json error: {0}")]
Json(#[from] serde_json::Error),
/// The produced response is unusable. In this file it is raised for an empty
/// assistant answer, a payload without a non-empty `answer` field, and for a
/// session error event received while streaming.
#[error("invalid response json: {0}")]
InvalidResponseJson(String),
/// Every attempt failed; `last_error` is the display text of the final failure.
#[error("all {attempts} attempts failed, last error: {last_error}")]
RetriesExhausted { attempts: usize, last_error: String },
}
/// Result alias whose error type is [`WebResearchError`].
pub type Result<T> = std::result::Result<T, WebResearchError>;
/// Input for `run_research`: what to ask, which server/model to use, and where
/// to write the result.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResearchRequest {
/// User prompt; a tool hint is appended to it when `tools` is non-empty.
pub prompt: String,
/// Host used to reach an external server and to bind a managed one.
pub opencode_server_hostname: String,
/// Server port; deserializes from either a JSON number or a numeric string.
#[serde(deserialize_with = "deserialize_u16_from_string_or_number")]
pub opencode_server_port: u16,
/// Provider identifier passed to `PromptRequest::with_model`.
pub llm_provider: String,
/// Model identifier passed to `PromptRequest::with_model`.
pub llm_mode_name: String,
/// Tool names mentioned in the prompt hint; defaults to empty.
#[serde(default)]
pub tools: Vec<String>,
/// Directory (created if missing) that receives `response.json`.
pub output_directory: PathBuf,
/// Directory used for the client, session and managed server; falls back to
/// the process's current directory when absent.
#[serde(default)]
pub working_directory: Option<PathBuf>,
/// Client timeout and event-wait deadline in seconds (default 180); values
/// below 1 are treated as 1.
#[serde(default = "default_timeout_secs")]
pub timeout_secs: u64,
/// Startup timeout handed to the managed server, in milliseconds (default 15000).
#[serde(default = "default_server_startup_timeout_ms")]
pub server_startup_timeout_ms: u64,
/// Maximum number of attempts (default 3); values below 1 are treated as 1.
#[serde(default = "default_max_attempts")]
pub max_attempts: usize,
/// Whether a server started by this module is stopped once `run_research`
/// finishes (default true).
#[serde(default = "default_shutdown_server_when_done")]
pub shutdown_server_when_done: bool,
}
/// Outcome of a successful `run_research` call.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResearchResponse {
/// Path of the JSON file the payload was written to.
pub response_path: PathBuf,
/// 1-based number of the attempt that succeeded.
pub attempt: usize,
/// True when the client used for the attempt talks to a server managed by
/// this module rather than an already-running external one.
pub started_server_internally: bool,
/// The data that was serialized into the response file.
pub payload: ResearchPayload,
}
/// Content serialized into the response file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResearchPayload {
/// The original request prompt (without the appended tool hint).
pub prompt: String,
/// Trimmed assistant answer; never empty in a written payload.
pub answer: String,
/// Identifier of the session the answer was produced in.
pub session_id: String,
/// 1-based number of the attempt that produced this payload.
pub attempt: usize,
/// Provider/model pair the prompt was sent with.
pub model: ModelDescriptor,
/// Tool names copied from the request.
pub tools: Vec<String>,
/// Number of messages listed for the session after it went idle.
pub message_count: usize,
/// Creation time in whole seconds since the Unix epoch (0 if the clock is
/// before the epoch).
pub generated_at_unix: u64,
}
/// Provider/model pair recorded in the payload.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelDescriptor {
/// Copied from `ResearchRequest::llm_provider`.
pub provider: String,
/// Copied from `ResearchRequest::llm_mode_name`.
pub model: String,
}
/// Runs the research prompt, retrying until one attempt succeeds or the
/// attempt budget (`request.max_attempts`, at least 1) is used up.
///
/// The output directory is created first. A server started by this module is
/// stopped afterwards when `request.shutdown_server_when_done` is set; a
/// failure to stop it is ignored.
///
/// # Errors
///
/// Returns the I/O error if the output directory cannot be created, or
/// [`WebResearchError::RetriesExhausted`] carrying the last attempt's error
/// text when every attempt fails.
pub async fn run_research(request: ResearchRequest) -> Result<ResearchResponse> {
    tokio::fs::create_dir_all(&request.output_directory).await?;

    let total_attempts = request.max_attempts.max(1);
    let mut server_slot: Option<ManagedServer> = None;

    // Holds the final outcome; each failed attempt overwrites the error text.
    let mut outcome: Result<ResearchResponse> = Err(WebResearchError::RetriesExhausted {
        attempts: total_attempts,
        last_error: String::from("unknown error"),
    });

    for attempt in 1..=total_attempts {
        match run_single_attempt(&request, attempt, &mut server_slot).await {
            Ok(response) => {
                outcome = Ok(response);
                break;
            }
            Err(err) => {
                outcome = Err(WebResearchError::RetriesExhausted {
                    attempts: total_attempts,
                    last_error: err.to_string(),
                });
            }
        }
    }

    if request.shutdown_server_when_done {
        if let Some(server) = server_slot.take() {
            // Best effort: a failed shutdown must not mask the research outcome.
            let _ = server.stop().await;
        }
    }

    outcome
}
/// Performs one attempt: obtains a client, creates an allow-all session, runs
/// the workflow in it and deletes the session afterwards.
///
/// The session is deleted whether or not the workflow succeeded. If the
/// workflow failed, its error wins; if only the deletion failed, the deletion
/// error is returned.
async fn run_single_attempt(
    request: &ResearchRequest,
    attempt: usize,
    managed_server: &mut Option<ManagedServer>,
) -> Result<ResearchResponse> {
    let (client, started_server_internally) = ensure_client(request, managed_server).await?;

    let directory = resolve_working_directory(request)?.display().to_string();
    let create_request = CreateSessionRequest {
        parent_id: None,
        title: Some(format!("opencode_webresearch attempt {}", attempt)),
        permission: Some(build_allow_all_ruleset()),
        directory: Some(directory),
    };
    let session = client.sessions().create(&create_request).await?;

    let workflow_outcome = run_session_workflow(&client, request, &session.id, attempt).await;
    let delete_outcome = client.sessions().delete(&session.id).await;

    match (workflow_outcome, delete_outcome) {
        (Err(workflow_err), _) => Err(workflow_err),
        (Ok(_), Err(delete_err)) => Err(delete_err.into()),
        (Ok(mut response), Ok(_)) => {
            // The workflow cannot know how the client was obtained; fill it in here.
            response.started_server_internally = started_server_internally;
            Ok(response)
        }
    }
}
/// Sends the prompt to an existing session, waits for the session to go idle,
/// and writes the resulting payload to `response.json` in the output directory.
///
/// The answer is the trimmed streamed text when there is any; otherwise the
/// text of the most recent assistant message with content is used.
///
/// # Errors
///
/// Propagates SDK, JSON and I/O errors, and returns
/// [`WebResearchError::InvalidResponseJson`] when no non-empty answer exists.
async fn run_session_workflow(
    client: &Client,
    request: &ResearchRequest,
    session_id: &str,
    attempt: usize,
) -> Result<ResearchResponse> {
    // Subscribe before prompting so events emitted right after the prompt are seen.
    let mut subscription = client.subscribe_session(session_id).await?;

    let prompt_text = build_prompt_with_tools(&request.prompt, &request.tools);
    let prompt_request = PromptRequest::text(prompt_text)
        .with_model(request.llm_provider.clone(), request.llm_mode_name.clone());
    client
        .messages()
        .prompt_async(session_id, &prompt_request)
        .await?;

    let wait_limit = Duration::from_secs(request.timeout_secs.max(1));
    let streamed_text = collect_text_until_idle(client, &mut subscription, wait_limit).await?;
    let messages = client.messages().list(session_id).await?;

    let streamed = streamed_text.trim();
    let answer = if streamed.is_empty() {
        extract_latest_assistant_text(&messages)
    } else {
        streamed.to_string()
    };
    if answer.trim().is_empty() {
        return Err(WebResearchError::InvalidResponseJson(
            "assistant answer is empty".to_string(),
        ));
    }

    let model = ModelDescriptor {
        provider: request.llm_provider.clone(),
        model: request.llm_mode_name.clone(),
    };
    let payload = ResearchPayload {
        prompt: request.prompt.clone(),
        answer,
        session_id: session_id.to_string(),
        attempt,
        model,
        tools: request.tools.clone(),
        message_count: messages.len(),
        generated_at_unix: unix_now(),
    };

    // Validate the exact text that is about to be written.
    let json_text = serde_json::to_string_pretty(&payload)?;
    validate_response_json(&json_text)?;

    let response_path = request.output_directory.join(RESPONSE_FILE_NAME);
    tokio::fs::write(&response_path, json_text).await?;

    Ok(ResearchResponse {
        response_path,
        attempt,
        // Placeholder; the caller overwrites this with the real value.
        started_server_internally: false,
        payload,
    })
}
/// Drains session events until the session reports idle, accumulating streamed
/// text deltas and auto-approving permission requests along the way.
///
/// `timeout` bounds the whole wait, not each individual event.
///
/// # Errors
///
/// * `OpencodeError::ServerTimeout` when the deadline passes.
/// * `OpencodeError::StreamClosed` when the subscription ends early.
/// * [`WebResearchError::InvalidResponseJson`] when a session error event arrives.
/// * Any error from replying to a permission request.
async fn collect_text_until_idle(
    client: &Client,
    subscription: &mut opencode_sdk::sse::SseSubscription,
    timeout: Duration,
) -> Result<String> {
    // Both timeout paths report the configured limit, saturated to u64.
    let timed_out = || {
        WebResearchError::Opencode(OpencodeError::ServerTimeout {
            timeout_ms: u64::try_from(timeout.as_millis()).unwrap_or(u64::MAX),
        })
    };

    let deadline = tokio::time::Instant::now() + timeout;
    let mut collected = String::new();

    loop {
        let now = tokio::time::Instant::now();
        if now >= deadline {
            return Err(timed_out());
        }
        let budget = deadline.saturating_duration_since(now);

        let event = match tokio::time::timeout(budget, subscription.recv()).await {
            Err(_) => return Err(timed_out()),
            Ok(None) => return Err(WebResearchError::Opencode(OpencodeError::StreamClosed)),
            Ok(Some(received)) => received,
        };

        match event {
            Event::SessionIdle { .. } => return Ok(collected),
            Event::SessionStatus { properties } => {
                if is_status_idle(&properties) {
                    return Ok(collected);
                }
            }
            Event::SessionError { properties } => {
                return Err(WebResearchError::InvalidResponseJson(format!(
                    "session error: {:?}",
                    properties.error
                )));
            }
            Event::PermissionAsked { properties } => {
                let reply = PermissionReplyRequest {
                    reply: PermissionReply::Always,
                    message: Some("Auto-approved by opencode_webresearch".to_string()),
                };
                client
                    .permissions()
                    .reply(&properties.request.id, &reply)
                    .await?;
            }
            Event::MessagePartUpdated { properties } => {
                // Keep deltas of text parts and deltas that carry no part at all.
                let accepts_delta = match properties.part.as_ref() {
                    None | Some(Part::Text { .. }) => true,
                    Some(_) => false,
                };
                if accepts_delta {
                    if let Some(delta) = properties.delta.as_deref() {
                        collected.push_str(delta);
                    }
                }
            }
            _ => {}
        }
    }
}
/// Returns a healthy client plus a flag telling whether it talks to a server
/// managed by this module (`true`) or to an external one (`false`).
///
/// Resolution order:
/// 1. reuse the tracked managed server if it is running and healthy;
/// 2. otherwise use an external server at the configured host/port if healthy;
/// 3. otherwise stop any tracked server and start a new managed one, retrying
///    once without an explicit port when the first start fails with a spawn or
///    startup-timeout error.
///
/// A newly started server is stored in `managed_server` *before* it is probed,
/// so that if building the client or the health probe fails the caller still
/// holds the handle and can stop (or later reuse) it instead of the handle
/// being dropped without `stop()`.
///
/// # Errors
///
/// Propagates errors from resolving the working directory, building a client,
/// starting the server, and the health probe of a newly started server.
async fn ensure_client(
    request: &ResearchRequest,
    managed_server: &mut Option<ManagedServer>,
) -> Result<(Client, bool)> {
    if let Some(existing_server) = managed_server.as_mut() {
        if existing_server.is_running() {
            let client = build_client(request, existing_server.url().to_string())?;
            if client.misc().health().await.is_ok() {
                return Ok((client, true));
            }
        }
    }

    let external_client = build_client(request, external_base_url(request))?;
    if external_client.misc().health().await.is_ok() {
        return Ok((external_client, false));
    }

    // Nothing usable is reachable: discard a stale managed server (best effort)
    // before starting a fresh one.
    if let Some(server) = managed_server.take() {
        let _ = server.stop().await;
    }

    let working_directory = resolve_working_directory(request)?;
    let requested_options = ServerOptions::new()
        .hostname(request.opencode_server_hostname.clone())
        .port(request.opencode_server_port)
        .directory(working_directory.clone())
        .startup_timeout_ms(request.server_startup_timeout_ms);
    let new_server = match ManagedServer::start(requested_options).await {
        Ok(server) => server,
        Err(first_error) if should_retry_managed_start_with_random_port(&first_error) => {
            // Second try omits the explicit port (per the helper's name this is
            // meant to get a random port -- confirm against the SDK).
            let fallback_options = ServerOptions::new()
                .hostname(request.opencode_server_hostname.clone())
                .directory(working_directory)
                .startup_timeout_ms(request.server_startup_timeout_ms);
            ManagedServer::start(fallback_options).await?
        }
        Err(first_error) => return Err(first_error.into()),
    };

    // Hand the server to the caller's slot first so the fallible steps below
    // cannot leave it untracked.
    let server_url = new_server.url().to_string();
    *managed_server = Some(new_server);

    let managed_client = build_client(request, server_url)?;
    managed_client.misc().health().await?;
    Ok((managed_client, true))
}
/// Builds an SDK client for `base_url`, scoped to the request's working
/// directory and using the request timeout (at least one second).
fn build_client(request: &ResearchRequest, base_url: String) -> Result<Client> {
    let directory = resolve_working_directory(request)?.display().to_string();
    let timeout_secs = request.timeout_secs.max(1);

    Ok(Client::builder()
        .base_url(base_url)
        .directory(directory)
        .timeout_secs(timeout_secs)
        .build()?)
}
fn resolve_working_directory(request: &ResearchRequest) -> Result<PathBuf> {
if let Some(path) = &request.working_directory {
return Ok(path.clone());
}
Ok(std::env::current_dir()?)
}
/// Returns the prompt to send: `base_prompt` unchanged when `tools` is empty,
/// otherwise `base_prompt` followed by a hint listing the tools (comma
/// separated) and a request for a factual, referenced answer.
fn build_prompt_with_tools(base_prompt: &str, tools: &[String]) -> String {
    match tools {
        [] => base_prompt.to_string(),
        listed => format!(
            "{}\n\nUse these MCP tools when available: {}. Return a factual answer with concrete references.",
            base_prompt,
            listed.join(", ")
        ),
    }
}
/// Builds the `http://host:port` base URL used to reach an already-running
/// server at the configured hostname and port.
///
/// Wildcard bind addresses are first mapped to loopback by
/// `normalize_connect_host`. IPv6 literals are wrapped in square brackets
/// (`http://[::1]:4096`), because a bare IPv6 address followed by `:port` is
/// not a valid URL authority.
fn external_base_url(request: &ResearchRequest) -> String {
    let connect_host = normalize_connect_host(&request.opencode_server_hostname);
    let port = request.opencode_server_port;

    // Only genuine IPv6 literals are bracketed; host names and IPv4 addresses
    // (and hosts that are already bracketed) are used as-is.
    if connect_host.parse::<std::net::Ipv6Addr>().is_ok() {
        format!("http://[{}]:{}", connect_host, port)
    } else {
        format!("http://{}:{}", connect_host, port)
    }
}
/// Turns a bind address into one that can be connected to: after trimming
/// whitespace, `0.0.0.0` becomes `127.0.0.1` and `::` becomes `::1`; any other
/// value is returned trimmed but otherwise unchanged.
fn normalize_connect_host(host: &str) -> String {
    let trimmed = host.trim();
    let connectable = if trimmed == "0.0.0.0" {
        "127.0.0.1"
    } else if trimmed == "::" {
        "::1"
    } else {
        trimmed
    };
    connectable.to_owned()
}
fn should_retry_managed_start_with_random_port(error: &OpencodeError) -> bool {
matches!(
error,
OpencodeError::SpawnServer { .. } | OpencodeError::ServerTimeout { .. }
)
}
/// Builds a ruleset with a single rule that allows every permission (`*`) for
/// every pattern (`*`).
fn build_allow_all_ruleset() -> Ruleset {
    let allow_everything = PermissionRule {
        permission: String::from("*"),
        pattern: String::from("*"),
        action: PermissionAction::Allow,
    };
    vec![allow_everything]
}
/// Reports whether a session-status payload says the session is idle.
///
/// Two shapes are recognised: `{"status": {"type": "idle"}}` and
/// `{"type": "idle"}`.
fn is_status_idle(properties: &serde_json::Value) -> bool {
    /// True when `node` has a string field `type` equal to `"idle"`.
    fn has_idle_type(node: &serde_json::Value) -> bool {
        node.get("type").and_then(serde_json::Value::as_str) == Some("idle")
    }

    let nested_idle = properties.get("status").map_or(false, has_idle_type);
    nested_idle || has_idle_type(properties)
}
/// Returns the trimmed, concatenated text parts of the most recent assistant
/// message that contains non-whitespace text, or an empty string when no such
/// message exists.
fn extract_latest_assistant_text(messages: &[Message]) -> String {
    messages
        .iter()
        .rev()
        .filter(|message| message.role() == "assistant")
        .map(|message| {
            // Join only the text parts; other part kinds contribute nothing.
            let mut joined = String::new();
            for part in &message.parts {
                if let Part::Text { text, .. } = part {
                    joined.push_str(text);
                }
            }
            joined.trim().to_string()
        })
        .find(|candidate| !candidate.is_empty())
        .unwrap_or_default()
}
/// Checks that `json_text` parses as JSON and has a string field `answer`
/// containing more than whitespace.
///
/// # Errors
///
/// Returns [`WebResearchError::Json`] when parsing fails and
/// [`WebResearchError::InvalidResponseJson`] when the answer is missing, not a
/// string, or blank.
fn validate_response_json(json_text: &str) -> Result<()> {
    let parsed: serde_json::Value = serde_json::from_str(json_text)?;

    let has_answer = parsed
        .get("answer")
        .and_then(serde_json::Value::as_str)
        .map_or(false, |content| !content.trim().is_empty());

    if has_answer {
        Ok(())
    } else {
        Err(WebResearchError::InvalidResponseJson(
            "missing non-empty answer field".to_string(),
        ))
    }
}
/// Returns the current time as whole seconds since the Unix epoch, or 0 when
/// the system clock reads earlier than the epoch.
fn unix_now() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|elapsed| elapsed.as_secs())
        .unwrap_or(0)
}
fn deserialize_u16_from_string_or_number<'de, D>(deserializer: D) -> std::result::Result<u16, D::Error>
where
D: Deserializer<'de>,
{
#[derive(Deserialize)]
#[serde(untagged)]
enum PortValue {
Number(u16),
String(String),
}
match PortValue::deserialize(deserializer)? {
PortValue::Number(value) => Ok(value),
PortValue::String(value) => value.parse::<u16>().map_err(serde::de::Error::custom),
}
}
/// Serde default for `ResearchRequest::timeout_secs`: 180 seconds.
fn default_timeout_secs() -> u64 {
    const DEFAULT_TIMEOUT_SECS: u64 = 180;
    DEFAULT_TIMEOUT_SECS
}
/// Serde default for `ResearchRequest::max_attempts`: three attempts.
fn default_max_attempts() -> usize {
    const DEFAULT_MAX_ATTEMPTS: usize = 3;
    DEFAULT_MAX_ATTEMPTS
}
/// Serde default for `ResearchRequest::server_startup_timeout_ms`: 15 000 ms.
fn default_server_startup_timeout_ms() -> u64 {
    const DEFAULT_STARTUP_TIMEOUT_MS: u64 = 15_000;
    DEFAULT_STARTUP_TIMEOUT_MS
}
/// Serde default for `ResearchRequest::shutdown_server_when_done`: enabled.
fn default_shutdown_server_when_done() -> bool {
    const SHUTDOWN_BY_DEFAULT: bool = true;
    SHUTDOWN_BY_DEFAULT
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Wildcard bind addresses map to loopback; other hosts pass through.
    #[test]
    fn normalize_wildcard_hosts() {
        let cases = [
            ("0.0.0.0", "127.0.0.1"),
            ("::", "::1"),
            ("localhost", "localhost"),
        ];
        for &(input, expected) in cases.iter() {
            assert_eq!(normalize_connect_host(input), expected);
        }
    }

    /// Spawn failures and startup timeouts are retryable; network errors are not.
    #[test]
    fn managed_start_retryable_errors_are_detected() {
        let retryable = [
            OpencodeError::SpawnServer {
                message: "boom".to_string(),
            },
            OpencodeError::ServerTimeout { timeout_ms: 1000 },
        ];
        for error in &retryable {
            assert!(should_retry_managed_start_with_random_port(error));
        }

        let not_retryable = OpencodeError::Network("no route".to_string());
        assert!(!should_retry_managed_start_with_random_port(&not_retryable));
    }

    /// Every requested tool name shows up in the generated prompt.
    #[test]
    fn prompt_builder_handles_tools() {
        let tools = vec!["searxng".to_string(), "webfetch".to_string()];
        let prompt = build_prompt_with_tools("Research topic", &tools);
        for tool in &tools {
            assert!(prompt.contains(tool.as_str()));
        }
    }

    /// Without tools the prompt is returned untouched.
    #[test]
    fn prompt_builder_without_tools_is_passthrough() {
        assert_eq!(
            build_prompt_with_tools("Only base prompt", &[]),
            "Only base prompt"
        );
    }

    /// Both idle shapes are recognised; a non-idle type is not.
    #[test]
    fn status_idle_detector_handles_known_shapes() {
        let nested_idle = serde_json::json!({ "status": { "type": "idle" } });
        let flat_idle = serde_json::json!({ "type": "idle" });
        let flat_running = serde_json::json!({ "type": "running" });

        assert!(is_status_idle(&nested_idle));
        assert!(is_status_idle(&flat_idle));
        assert!(!is_status_idle(&flat_running));
    }

    /// A whitespace-only answer is rejected.
    #[test]
    fn response_validator_rejects_empty_answer() {
        let outcome = validate_response_json(r#"{"answer":" "}"#);
        assert!(outcome.is_err());
    }

    /// A non-blank answer is accepted.
    #[test]
    fn response_validator_accepts_non_empty_answer() {
        let outcome = validate_response_json(r#"{"answer":"content"}"#);
        assert!(outcome.is_ok());
    }

    /// The allow-all ruleset consists of exactly one `*`/`*` rule.
    #[test]
    fn allow_all_ruleset_is_non_empty() {
        let ruleset = build_allow_all_ruleset();
        assert_eq!(ruleset.len(), 1);

        let rule = &ruleset[0];
        assert_eq!(rule.permission, "*");
        assert_eq!(rule.pattern, "*");
    }

    /// The port deserializer accepts both a JSON number and a numeric string.
    #[test]
    fn deserialize_port_from_string_and_number() {
        #[derive(Deserialize)]
        struct Wrapper {
            #[serde(deserialize_with = "deserialize_u16_from_string_or_number")]
            port: u16,
        }

        let inputs = [r#"{"port":7777}"#, r#"{"port":"7777"}"#];
        for raw in inputs.iter() {
            let parsed: Wrapper = serde_json::from_str(raw).unwrap();
            assert_eq!(parsed.port, 7777);
        }
    }
}