use serde_json::Value;
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio_stream::wrappers::UnboundedReceiverStream;
use tokio_stream::{Stream, StreamExt};
use crate::error::ClawError;
use crate::messages::Message;
use crate::options::ClaudeAgentOptions;
use crate::transport::{SubprocessCLITransport, Transport};
/// Stream of parsed [`Message`]s read from a CLI subprocess.
///
/// Bundles the message stream `S` with the [`SubprocessCLITransport`] that feeds it, so the
/// transport is owned by (and dropped together with) the stream handed back to callers.
#[must_use = "streams do nothing unless polled"]
pub struct QueryStream<S> {
    /// The wrapped stream of parsed messages; polling `QueryStream` polls this.
    inner: S,
    /// Never read after construction — it is stored only so the transport is not dropped
    /// while the stream is still being consumed, hence the `dead_code` allowance.
    /// Because struct fields drop in declaration order, this is dropped after `inner`.
    /// NOTE(review): presumably dropping the transport tears down the subprocess; confirm
    /// against `SubprocessCLITransport`'s `Drop` behavior.
    #[allow(dead_code)]
    transport: SubprocessCLITransport,
}
impl<S> QueryStream<S>
where
S: Stream<Item = Result<Message, ClawError>>,
{
fn new(transport: SubprocessCLITransport, inner: S) -> Self {
Self { inner, transport }
}
}
impl<S> Stream for QueryStream<S>
where
    S: Stream<Item = Result<Message, ClawError>> + Unpin,
{
    type Item = Result<Message, ClawError>;

    /// Forwards the poll to the wrapped stream; the held transport is never touched here.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `QueryStream<S>` is `Unpin` whenever `S` is (given the bound above), so the pin
        // can be safely unwrapped and the inner stream re-pinned on the stack.
        let this = self.get_mut();
        let pinned_inner = Pin::new(&mut this.inner);
        pinned_inner.poll_next(cx)
    }
}
/// Runs a single one-shot prompt through the CLI and streams back its parsed messages.
///
/// When `options` is `None`, the CLI is launched with `--output-format stream-json`,
/// `--verbose`, an empty `--setting-sources` list and the prompt passed via `-p`.
/// Otherwise the argument list comes from `ClaudeAgentOptions::to_cli_args`, and the
/// process-level settings in `options` (`cli_path`, `cwd`, a non-empty `env`,
/// `stderr_callback`, `max_buffer_size`) are applied to the transport — the same way
/// `query_with_messages` applies them.
///
/// Input is closed right after connecting because the prompt travels on the command line.
///
/// # Errors
///
/// Returns an error if the transport fails to connect or to close its input. Each item of
/// the returned stream is either a parsed [`Message`], an error reported by the transport,
/// or `ClawError::MessageParse` carrying the raw JSON that failed to deserialize.
pub async fn query(
    prompt: impl Into<String>,
    options: Option<ClaudeAgentOptions>,
) -> Result<impl Stream<Item = Result<Message, ClawError>>, ClawError> {
    let prompt = prompt.into();
    let opts = options.as_ref();
    let args = match opts {
        Some(opts) => opts.to_cli_args(&prompt),
        None => vec![
            "--output-format".to_string(),
            "stream-json".to_string(),
            "--verbose".to_string(),
            "--setting-sources".to_string(),
            String::new(),
            "-p".to_string(),
            prompt,
        ],
    };

    // Process-level settings are not CLI arguments, so they must be applied to the
    // transport directly; previously they were dropped on the floor for this entry point.
    let mut transport =
        SubprocessCLITransport::new(opts.and_then(|o| o.cli_path.clone()), args);
    if let Some(cwd) = opts.and_then(|o| o.cwd.as_ref()) {
        transport.set_cwd(cwd.clone());
    }
    if let Some(env) = opts.map(|o| &o.env).filter(|e| !e.is_empty()) {
        transport.set_env(env.clone());
    }
    if let Some(cb) = opts.and_then(|o| o.stderr_callback.as_ref()) {
        let cb_clone = cb.clone();
        transport.set_stderr_callback(move |line| cb_clone(line));
    }
    if let Some(size) = opts.and_then(|o| o.max_buffer_size) {
        transport.set_max_buffer_size(size);
    }

    transport.connect().await?;
    transport.end_input().await?;

    let rx = transport.messages();
    let stream = UnboundedReceiverStream::new(rx).map(|result| {
        result.and_then(|value| {
            // Render the raw JSON before `from_value` consumes `value`, so a parse failure
            // can report exactly what was received.
            let raw = value.to_string();
            serde_json::from_value::<Message>(value).map_err(|e| ClawError::MessageParse {
                reason: e.to_string(),
                raw,
            })
        })
    });
    Ok(QueryStream::new(transport, stream))
}
/// Sends a sequence of JSON input messages to the CLI and streams back its parsed output.
///
/// The CLI is launched in `stream-json` input/output mode (arguments come from
/// `build_stream_args`). Each item of `messages` is serialized as one JSON document
/// followed by a newline and written to the subprocess; once `messages` is exhausted the
/// input side is closed and the output stream is returned.
///
/// When `options` is provided, its `cli_path`, `cwd`, non-empty `env`, `stderr_callback`
/// and `max_buffer_size` are applied to the transport before it connects.
///
/// # Errors
///
/// Fails if the transport cannot connect, if a message cannot be serialized
/// (`ClawError::JsonDecode`), or if writing or closing the input fails. Items of the returned
/// stream that cannot be deserialized surface as `ClawError::MessageParse` with the raw JSON.
pub async fn query_with_messages(
    mut messages: impl Stream<Item = Value> + Unpin,
    options: Option<ClaudeAgentOptions>,
) -> Result<impl Stream<Item = Result<Message, ClawError>>, ClawError> {
    let opts = options.as_ref();
    let cli_args = build_stream_args(opts);
    let cli_path = opts.and_then(|o| o.cli_path.clone());
    let mut transport = SubprocessCLITransport::new(cli_path, cli_args);

    if let Some(dir) = opts.and_then(|o| o.cwd.as_ref()) {
        transport.set_cwd(dir.clone());
    }
    if let Some(vars) = opts.map(|o| &o.env).filter(|vars| !vars.is_empty()) {
        transport.set_env(vars.clone());
    }
    if let Some(callback) = opts.and_then(|o| o.stderr_callback.as_ref()) {
        let callback = callback.clone();
        transport.set_stderr_callback(move |line| callback(line));
    }
    if let Some(limit) = opts.and_then(|o| o.max_buffer_size) {
        transport.set_max_buffer_size(limit);
    }

    transport.connect().await?;

    while let Some(message) = messages.next().await {
        // One JSON document per line on the subprocess's input.
        let mut frame = serde_json::to_string(&message).map_err(ClawError::JsonDecode)?;
        frame.push('\n');
        transport.write(frame.as_bytes()).await?;
    }
    transport.end_input().await?;

    let incoming = transport.messages();
    let parsed = UnboundedReceiverStream::new(incoming).map(
        |item| -> Result<Message, ClawError> {
            let value = item?;
            // Captured before `from_value` takes ownership so failures can echo the input.
            let raw = value.to_string();
            serde_json::from_value::<Message>(value).map_err(|err| ClawError::MessageParse {
                reason: err.to_string(),
                raw,
            })
        },
    );
    Ok(QueryStream::new(transport, parsed))
}
fn build_stream_args(options: Option<&ClaudeAgentOptions>) -> Vec<String> {
let mut args = vec![
"--output-format".to_string(),
"stream-json".to_string(),
"--verbose".to_string(),
"--input-format".to_string(),
"stream-json".to_string(),
];
if let Some(opts) = options {
if let Some(max_turns) = opts.max_turns {
args.push("--max-turns".to_string());
args.push(max_turns.to_string());
}
if let Some(model) = &opts.model {
args.push("--model".to_string());
args.push(model.clone());
}
if let Some(mode) = &opts.permission_mode {
args.push("--permission-mode".to_string());
args.push(mode.to_cli_arg().to_string());
}
if let Some(sys_prompt) = &opts.system_prompt {
match sys_prompt {
crate::options::SystemPrompt::Custom(text) => {
args.push("--system-prompt".to_string());
args.push(text.clone());
}
crate::options::SystemPrompt::Preset { preset } => {
args.push("--system-prompt-preset".to_string());
args.push(preset.clone());
}
}
}
if !opts.allowed_tools.is_empty() {
args.push("--allowed-tools".to_string());
args.push(opts.allowed_tools.join(","));
}
for beta in &opts.betas {
args.push("--beta".to_string());
args.push(beta.clone());
}
if let Some(fallback) = &opts.fallback_model {
args.push("--fallback-model".to_string());
args.push(fallback.clone());
}
if let Some(user) = &opts.user {
args.push("--user".to_string());
args.push(user.clone());
}
match &opts.setting_sources {
Some(sources) => {
args.push("--setting-sources".to_string());
args.push(sources.join(","));
}
None => {
args.push("--setting-sources".to_string());
args.push(String::new());
}
}
for (key, value) in &opts.extra_args {
let flag = if key.starts_with("--") {
key.clone()
} else {
format!("--{}", key)
};
args.push(flag);
if let Some(val) = value {
args.push(val.clone());
}
}
} else {
args.push("--setting-sources".to_string());
args.push(String::new());
}
args
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A concrete inner stream type for the compile-time trait checks below.
    type ConcreteStream = UnboundedReceiverStream<Result<Message, ClawError>>;

    /// `QueryStream` over a concrete stream must be `Send`.
    #[test]
    fn test_query_stream_is_send() {
        fn assert_send<T: Send>() {}
        assert_send::<QueryStream<ConcreteStream>>();
    }

    /// `QueryStream` over an `Unpin` stream must itself be `Unpin`, which `poll_next` relies on.
    #[test]
    fn test_query_stream_is_unpin() {
        fn assert_unpin<T: Unpin>() {}
        assert_unpin::<QueryStream<ConcreteStream>>();
    }

    /// Compile-only check: `query` accepts an owned `String` prompt.
    #[test]
    fn test_query_accepts_string() {
        fn _assert_compiles() {
            async fn _test() -> Result<(), ClawError> {
                let _ = query("test".to_string(), None).await?;
                Ok(())
            }
        }
    }

    /// Compile-only check: `query` accepts a `&str` prompt.
    #[test]
    fn test_query_accepts_str() {
        fn _assert_compiles() {
            async fn _test() -> Result<(), ClawError> {
                let _ = query("test", None).await?;
                Ok(())
            }
        }
    }

    /// Without options, the stream-json session args end with an empty `--setting-sources`.
    #[test]
    fn test_build_stream_args_without_options() {
        let args = build_stream_args(None);
        assert_eq!(
            args,
            vec![
                "--output-format",
                "stream-json",
                "--verbose",
                "--input-format",
                "stream-json",
                "--setting-sources",
                "",
            ]
        );
    }
}