Skip to main content

claude_wrapper/
streaming.rs

1#[cfg(feature = "json")]
2use std::time::Duration;
3
4#[cfg(all(feature = "json", feature = "async"))]
5use tokio::io::{AsyncBufReadExt, AsyncReadExt, BufReader};
6#[cfg(all(feature = "json", feature = "async"))]
7use tokio::process::{ChildStderr, Command};
8#[cfg(feature = "json")]
9use tracing::{debug, warn};
10
11#[cfg(feature = "json")]
12use crate::Claude;
13#[cfg(feature = "json")]
14use crate::error::{Error, Result};
15#[cfg(feature = "json")]
16use crate::exec::CommandOutput;
17
/// One NDJSON record emitted by `claude --output-format stream-json`.
///
/// The record's shape depends on its `type` field, so the whole object is
/// kept as an untyped [`serde_json::Value`]; the methods on this type give
/// typed access to the commonly used fields.
#[cfg(feature = "json")]
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize)]
pub struct StreamEvent {
    /// Every top-level field of the JSON object (collected via
    /// `#[serde(flatten)]`).
    #[serde(flatten)]
    pub data: serde_json::Value,
}
29
#[cfg(feature = "json")]
impl StreamEvent {
    /// Get the event type (the top-level `type` string), if present.
    pub fn event_type(&self) -> Option<&str> {
        self.data.get("type").and_then(|v| v.as_str())
    }

    /// Get the top-level `role` string, if present.
    pub fn role(&self) -> Option<&str> {
        self.data.get("role").and_then(|v| v.as_str())
    }

    /// Check if this is the final result message (`"type": "result"`).
    pub fn is_result(&self) -> bool {
        self.event_type() == Some("result")
    }

    /// Extract the `result` text from a result event.
    pub fn result_text(&self) -> Option<&str> {
        self.data.get("result").and_then(|v| v.as_str())
    }

    /// Get the session ID if present.
    pub fn session_id(&self) -> Option<&str> {
        self.data.get("session_id").and_then(|v| v.as_str())
    }

    /// Get the cost in USD if present (usually on result events).
    ///
    /// Prefers `total_cost_usd` (the CLI's primary key) and falls back
    /// to the legacy `cost_usd` alias. The fallback also applies when
    /// `total_cost_usd` is present but not a number (e.g. `null`), so a
    /// placeholder primary key cannot hide a usable legacy value.
    pub fn cost_usd(&self) -> Option<f64> {
        let number_at = |key: &str| self.data.get(key).and_then(serde_json::Value::as_f64);
        number_at("total_cost_usd").or_else(|| number_at("cost_usd"))
    }

    /// Decode a partial-message event into a typed view.
    ///
    /// Returns `Some` when the event is one of the content-block lifecycle
    /// events surfaced by [`QueryCommand::include_partial_messages`] -- start,
    /// delta, or stop. Returns `None` for any other event (system, assistant,
    /// result, message-level stream events, etc), and also when a required
    /// field (`index`, `content_block`, `delta`) is missing or `index` does
    /// not fit in a `u32`.
    ///
    /// The CLI wraps each raw streaming event as
    /// `{"type":"stream_event","event":{...}}`; this accessor unwraps that
    /// envelope. Events that are not wrapped are inspected directly. Unknown
    /// block types and unknown delta types fall through to
    /// [`BlockType::Other`] / [`BlockDelta::Other`] rather than erroring, so
    /// future content-block kinds remain accessible (just untyped).
    ///
    /// # Example
    ///
    /// Pull incremental thinking text out of a partial-message event:
    ///
    /// ```
    /// use claude_wrapper::streaming::{BlockDelta, PartialMessageEvent, StreamEvent};
    /// use serde_json::json;
    ///
    /// let event: StreamEvent = serde_json::from_value(json!({
    ///     "type": "stream_event",
    ///     "event": {
    ///         "type": "content_block_delta",
    ///         "index": 0,
    ///         "delta": { "type": "thinking_delta", "thinking": "Let me think..." }
    ///     },
    ///     "session_id": "abc"
    /// })).unwrap();
    ///
    /// match event.partial_message() {
    ///     Some(PartialMessageEvent::BlockDelta { delta: BlockDelta::Thinking(t), .. }) => {
    ///         assert_eq!(t, "Let me think...");
    ///     }
    ///     _ => unreachable!(),
    /// }
    /// ```
    ///
    /// [`QueryCommand::include_partial_messages`]: crate::QueryCommand::include_partial_messages
    pub fn partial_message(&self) -> Option<PartialMessageEvent> {
        // Unwrap the `stream_event` envelope when present; otherwise treat
        // the record itself as the raw streaming event.
        let event = if self.event_type() == Some("stream_event") {
            self.data.get("event")?
        } else {
            &self.data
        };

        let inner_type = event.get("type")?.as_str()?;
        let index = event.get("index").and_then(serde_json::Value::as_u64)?;
        // Out-of-range indices are treated as malformed rather than truncated.
        let index = u32::try_from(index).ok()?;

        match inner_type {
            "content_block_start" => {
                let block_type = parse_block_type(event.get("content_block")?);
                Some(PartialMessageEvent::BlockStart { index, block_type })
            }
            "content_block_delta" => {
                let delta = parse_block_delta(event.get("delta")?);
                Some(PartialMessageEvent::BlockDelta { index, delta })
            }
            "content_block_stop" => Some(PartialMessageEvent::BlockStop { index }),
            _ => None,
        }
    }
}
133
/// Typed view of one content-block lifecycle event from a streaming
/// `claude` run.
///
/// Produced by [`StreamEvent::partial_message`] when the query runs with
/// `--include-partial-messages`. The variants follow the Anthropic
/// streaming content-block lifecycle: a block starts, receives incremental
/// deltas, then stops, and every event for that block carries the same
/// `index`.
#[cfg(feature = "json")]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PartialMessageEvent {
    /// A content block has opened; `block_type` identifies its kind.
    BlockStart {
        /// Position of the block within the assistant message.
        index: u32,
        /// Kind of the opening block (text, thinking, tool use, ...).
        block_type: BlockType,
    },
    /// A chunk of content for a block that is still open.
    BlockDelta {
        /// Index of the block this chunk belongs to; matches the `index`
        /// of an earlier [`PartialMessageEvent::BlockStart`].
        index: u32,
        /// The chunk itself.
        delta: BlockDelta,
    },
    /// The block at `index` has closed.
    BlockStop {
        /// Index of the block that closed.
        index: u32,
    },
}
164
/// Kind of content block announced by a [`PartialMessageEvent::BlockStart`].
///
/// Mirrors `content_block.type` from the Anthropic streaming API. Block
/// kinds this crate does not model yet arrive as [`BlockType::Other`]
/// carrying the raw type name. A block with no string `type` at all is
/// reported as `Other` with an empty string.
#[cfg(feature = "json")]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BlockType {
    /// Plain assistant text; its deltas are `text_delta`s.
    Text,
    /// Extended-thinking output; its deltas are `thinking_delta`s.
    Thinking,
    /// A tool call; its deltas are `input_json_delta`s carrying the tool
    /// input as streamed JSON text.
    ToolUse {
        /// Tool-call id, used to match the call with its later result.
        /// Empty if the CLI did not send one.
        id: String,
        /// Name of the invoked tool. Empty if the CLI did not send one.
        name: String,
    },
    /// A block kind not modelled above, holding the raw `type` value.
    Other(String),
}
187
/// Payload of a [`PartialMessageEvent::BlockDelta`].
///
/// Mirrors `delta.type` from the Anthropic streaming API. Delta kinds that
/// are not modelled here (signature, citations, compaction, ...) -- and
/// modelled kinds whose payload field is missing or not a string -- become
/// [`BlockDelta::Other`]; the raw JSON remains available through
/// [`StreamEvent::data`].
#[cfg(feature = "json")]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BlockDelta {
    /// A fragment of assistant text (`text_delta`).
    Text(String),
    /// A fragment of extended-thinking text (`thinking_delta`).
    Thinking(String),
    /// A fragment of tool-input JSON (`input_json_delta`). Fragments are not
    /// valid JSON on their own; concatenate them in order to rebuild the
    /// full input.
    InputJson(String),
    /// Any delta not captured above (e.g. `signature_delta`,
    /// `citations_delta`); inspect [`StreamEvent::data`] for the payload.
    Other,
}
208
/// Map a `content_block` object onto [`BlockType`].
///
/// A missing or non-string `type` yields `BlockType::Other("")`. For
/// `tool_use`, missing `id`/`name` fields become empty strings.
#[cfg(feature = "json")]
fn parse_block_type(content_block: &serde_json::Value) -> BlockType {
    // Borrowed string field of the block, if present and a string.
    let field = |key: &str| content_block.get(key).and_then(serde_json::Value::as_str);

    match field("type") {
        None => BlockType::Other(String::new()),
        Some("text") => BlockType::Text,
        Some("thinking") => BlockType::Thinking,
        Some("tool_use") => BlockType::ToolUse {
            id: field("id").unwrap_or_default().to_owned(),
            name: field("name").unwrap_or_default().to_owned(),
        },
        Some(unknown) => BlockType::Other(unknown.to_owned()),
    }
}
236
/// Map a `delta` object onto [`BlockDelta`].
///
/// Known delta kinds are only typed when their payload field is a string;
/// everything else -- unknown kinds, a missing `type`, or a missing
/// payload -- collapses to [`BlockDelta::Other`].
#[cfg(feature = "json")]
fn parse_block_delta(delta: &serde_json::Value) -> BlockDelta {
    // Owned copy of a string field of the delta, if present.
    let text_of = |key: &str| {
        delta
            .get(key)
            .and_then(serde_json::Value::as_str)
            .map(str::to_owned)
    };

    let typed = match delta.get("type").and_then(serde_json::Value::as_str) {
        Some("text_delta") => text_of("text").map(BlockDelta::Text),
        Some("thinking_delta") => text_of("thinking").map(BlockDelta::Thinking),
        Some("input_json_delta") => text_of("partial_json").map(BlockDelta::InputJson),
        _ => None,
    };
    typed.unwrap_or(BlockDelta::Other)
}
261
/// Run a query and stream its NDJSON output to `handler` as it arrives.
///
/// Spawns the `claude` process, reads stdout one line at a time, parses
/// each non-blank line as a [`StreamEvent`], and calls `handler` with it
/// immediately, which makes it suitable for progress tracking or live
/// rendering. Lines that do not parse are skipped (logged at debug level).
/// The timeout configured on the [`Claude`] client, if any, applies to the
/// whole run.
///
/// The returned [`CommandOutput`] has an empty `stdout` (the lines were
/// consumed by `handler`) and carries the captured stderr.
///
/// # Errors
///
/// Returns `Error::Io` if the process cannot be spawned, read from, or
/// waited on; `Error::Timeout` if the client's timeout elapses; and
/// `Error::CommandFailed` if the process exits unsuccessfully.
///
/// # Example
///
/// ```no_run
/// use claude_wrapper::{Claude, QueryCommand, OutputFormat};
/// use claude_wrapper::streaming::{StreamEvent, stream_query};
///
/// # async fn example() -> claude_wrapper::Result<()> {
/// let claude = Claude::builder().build()?;
///
/// let cmd = QueryCommand::new("explain quicksort")
///     .output_format(OutputFormat::StreamJson);
///
/// let output = stream_query(&claude, &cmd, |event: StreamEvent| {
///     if let Some(t) = event.event_type() {
///         println!("[{t}] {:?}", event.data);
///     }
/// }).await?;
/// # Ok(())
/// # }
/// ```
#[cfg(all(feature = "json", feature = "async"))]
pub async fn stream_query<F>(
    claude: &Claude,
    cmd: &crate::command::query::QueryCommand,
    handler: F,
) -> Result<CommandOutput>
where
    F: FnMut(StreamEvent),
{
    // The client-level timeout governs the whole streaming call.
    let deadline = claude.timeout;
    stream_query_impl(claude, cmd, handler, deadline).await
}
299
/// Unified streaming implementation with optional timeout.
///
/// Stdout lines are parsed and dispatched to `handler` while stderr is
/// drained concurrently on the same task (`tokio::join!`), so a chatty
/// child cannot deadlock by filling the stderr pipe buffer.
///
/// On timeout, the child is killed and reaped (`kill().await` sends
/// SIGKILL and waits). Stderr read before the deadline is kept in a
/// buffer owned by this function, topped up with whatever the pipe still
/// yields within a short drain budget, and logged at warn level. The
/// returned `Error::Timeout` does not carry partial output -- streamed
/// stdout events were already dispatched to the handler as they arrived.
///
/// If reading stdout fails with an IO error, the child is killed and that
/// error is returned; captured stderr is discarded on that path.
///
/// # Errors
///
/// * `Error::Io` -- spawning the process, reading a stdout line, or
///   waiting for the process failed.
/// * `Error::Timeout` -- `timeout` elapsed before both stdout and stderr
///   reached EOF.
/// * `Error::CommandFailed` -- the process exited unsuccessfully.
#[cfg(all(feature = "json", feature = "async"))]
async fn stream_query_impl<F>(
    claude: &Claude,
    cmd: &crate::command::query::QueryCommand,
    mut handler: F,
    timeout: Option<Duration>,
) -> Result<CommandOutput>
where
    F: FnMut(StreamEvent),
{
    use crate::command::ClaudeCommand;

    let args = cmd.args();

    let mut command_args = Vec::new();
    command_args.extend(claude.global_args.clone());
    command_args.extend(args);

    debug!(
        binary = %claude.binary.display(),
        args = ?command_args,
        timeout = ?timeout,
        "streaming claude command"
    );

    let mut cmd = Command::new(&claude.binary);
    cmd.args(&command_args)
        // Same environment scrubbing as `stream_query_sync`, so the async
        // and blocking entry points launch the child identically.
        .env_remove("CLAUDECODE")
        .env_remove("CLAUDE_CODE_ENTRYPOINT")
        .envs(&claude.env)
        .stdout(std::process::Stdio::piped())
        .stderr(std::process::Stdio::piped())
        .stdin(std::process::Stdio::null());

    if let Some(ref dir) = claude.working_dir {
        cmd.current_dir(dir);
    }

    let mut child = cmd.spawn().map_err(|e| Error::Io {
        message: format!("failed to spawn claude: {e}"),
        source: e,
        working_dir: claude.working_dir.clone(),
    })?;

    let stdout = child.stdout.take().expect("stdout was piped");
    let mut stderr = child.stderr.take().expect("stderr was piped");

    let mut reader = BufReader::new(stdout).lines();
    let working_dir = claude.working_dir.clone();

    // Run stdout line reading and stderr draining concurrently so a
    // chatty child can't deadlock by filling the stderr pipe buffer.
    // tokio::join! polls both futures on the same task (no tokio::spawn
    // needed, so we avoid pulling in the `rt` feature).
    let (line_result, stderr_str) = match timeout {
        // Without a deadline both futures always run to completion, so
        // draining into a future-local buffer cannot lose anything.
        None => tokio::join!(
            read_lines(&mut reader, &mut handler, working_dir),
            drain_stderr(&mut stderr)
        ),
        Some(d) => {
            // With a deadline the join can be cancelled mid-drain. Stderr is
            // therefore appended to `stderr_buf`, which outlives the
            // cancelled future, instead of to a buffer the future owns.
            let mut stderr_buf = Vec::new();
            let combined = async {
                tokio::join!(
                    read_lines(&mut reader, &mut handler, working_dir),
                    drain_stderr_into(&mut stderr, &mut stderr_buf)
                )
            };
            match tokio::time::timeout(d, combined).await {
                Ok((line_result, ())) => (
                    line_result,
                    String::from_utf8_lossy(&stderr_buf).into_owned(),
                ),
                Err(_) => {
                    // Timeout: `combined` has been dropped, but everything it
                    // read is already in `stderr_buf`. Kill the child (reaps
                    // via start_kill + wait) and collect what remains in the
                    // pipe. kill() only targets the direct child, so a
                    // subprocess tree holding our pipe fds could block the
                    // drain -- cap it with a short deadline.
                    let _ = child.kill().await;
                    let drain_budget = Duration::from_millis(200);
                    let _ = tokio::time::timeout(
                        drain_budget,
                        drain_stderr_into(&mut stderr, &mut stderr_buf),
                    )
                    .await;
                    let stderr_str = String::from_utf8_lossy(&stderr_buf);
                    if !stderr_str.is_empty() {
                        warn!(stderr = %stderr_str, "stderr from timed-out streaming process");
                    }
                    return Err(Error::Timeout {
                        timeout_seconds: d.as_secs(),
                    });
                }
            }
        }
    };

    // If reading lines failed partway through (IO error, not timeout),
    // clean up the child before returning.
    if let Err(e) = line_result {
        let _ = child.kill().await;
        return Err(e);
    }

    let status = child.wait().await.map_err(|e| Error::Io {
        message: "failed to wait for claude process".to_string(),
        source: e,
        working_dir: claude.working_dir.clone(),
    })?;

    let exit_code = status.code().unwrap_or(-1);

    if !status.success() {
        return Err(Error::CommandFailed {
            command: format!("{} {}", claude.binary.display(), command_args.join(" ")),
            exit_code,
            stdout: String::new(),
            stderr: stderr_str,
            working_dir: claude.working_dir.clone(),
        });
    }

    Ok(CommandOutput {
        stdout: String::new(), // already consumed via streaming
        stderr: stderr_str,
        exit_code,
        success: true,
    })
}

/// Read `stderr` to EOF, appending each chunk to `buf` as it arrives.
///
/// Unlike `drain_stderr`, the bytes land in a caller-owned buffer after
/// every successful read, and `AsyncReadExt::read` is cancel-safe, so
/// dropping this future (e.g. when a timeout fires) loses nothing that was
/// already read. Best-effort: any read error other than `Interrupted`
/// simply ends the drain.
#[cfg(all(feature = "json", feature = "async"))]
async fn drain_stderr_into(stderr: &mut ChildStderr, buf: &mut Vec<u8>) {
    let mut chunk = [0u8; 4096];
    loop {
        match stderr.read(&mut chunk).await {
            Ok(0) => break,
            Ok(n) => buf.extend_from_slice(&chunk[..n]),
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => {}
            Err(_) => break,
        }
    }
}
427
/// Read `stderr` until EOF and return its contents as text.
///
/// Best-effort: a read error is deliberately ignored and whatever is in the
/// buffer is converted as-is. Invalid UTF-8 is replaced with U+FFFD.
#[cfg(all(feature = "json", feature = "async"))]
async fn drain_stderr(stderr: &mut ChildStderr) -> String {
    let mut collected = Vec::new();
    let _ = stderr.read_to_end(&mut collected).await;
    String::from_utf8_lossy(&collected).into_owned()
}
434
/// Feed every parseable stdout line to `handler` until EOF.
///
/// Blank lines are skipped silently; lines that are not valid
/// [`StreamEvent`] JSON are skipped and logged at debug level. A read error
/// aborts the loop and is returned as `Error::Io` tagged with
/// `working_dir`.
#[cfg(all(feature = "json", feature = "async"))]
async fn read_lines<F>(
    reader: &mut tokio::io::Lines<BufReader<tokio::process::ChildStdout>>,
    handler: &mut F,
    working_dir: Option<std::path::PathBuf>,
) -> Result<()>
where
    F: FnMut(StreamEvent),
{
    loop {
        let next = reader.next_line().await.map_err(|source| Error::Io {
            message: "failed to read stdout line".to_string(),
            source,
            working_dir: working_dir.clone(),
        })?;
        let Some(line) = next else {
            return Ok(());
        };
        if line.trim().is_empty() {
            continue;
        }
        match serde_json::from_str::<StreamEvent>(&line) {
            Ok(event) => handler(event),
            Err(e) => {
                debug!(line = %line, error = %e, "failed to parse stream event, skipping");
            }
        }
    }
}
462
463// ---------- sync streaming ----------
464
/// Blocking counterpart of [`stream_query`].
///
/// Spawns `claude`, parses each NDJSON line of its stdout on a worker
/// thread, and hands every parsed [`StreamEvent`] to `handler` on the
/// calling thread. Stderr is collected on a second worker thread so a full
/// pipe can never stall the child.
///
/// Requires both `sync` and `json` features.
///
/// Because `handler` runs on the caller's thread it needs no `Send` bound
/// and may capture non-`Send` state. When the [`Claude`] client has a
/// timeout, the child is SIGKILLed and reaped once that deadline passes;
/// events already handed to `handler` are not rolled back.
///
/// # Errors
///
/// Returns `Error::Io` if the process cannot be spawned, read from, or
/// waited on; `Error::Timeout` if the client's timeout elapses; and
/// `Error::CommandFailed` if the process exits unsuccessfully.
///
/// # Example
///
/// ```no_run
/// # #[cfg(all(feature = "sync", feature = "json"))]
/// # {
/// use claude_wrapper::{Claude, OutputFormat, QueryCommand};
/// use claude_wrapper::streaming::{StreamEvent, stream_query_sync};
///
/// # fn example() -> claude_wrapper::Result<()> {
/// let claude = Claude::builder().build()?;
/// let cmd = QueryCommand::new("explain quicksort")
///     .output_format(OutputFormat::StreamJson);
///
/// stream_query_sync(&claude, &cmd, |event: StreamEvent| {
///     if let Some(t) = event.event_type() {
///         println!("[{t}] {:?}", event.data);
///     }
/// })?;
/// # Ok(())
/// # }
/// # }
/// ```
#[cfg(all(feature = "sync", feature = "json"))]
pub fn stream_query_sync<F>(
    claude: &Claude,
    cmd: &crate::command::query::QueryCommand,
    mut handler: F,
) -> Result<CommandOutput>
where
    F: FnMut(StreamEvent),
{
    use std::process::{Command as StdCommand, Stdio};
    use std::sync::mpsc::{self, RecvTimeoutError};
    use std::time::Instant;

    use crate::command::ClaudeCommand;

    let args = cmd.args();
    let mut command_args = Vec::new();
    command_args.extend(claude.global_args.clone());
    command_args.extend(args);

    debug!(
        binary = %claude.binary.display(),
        args = ?command_args,
        timeout = ?claude.timeout,
        "streaming claude command (sync)"
    );

    let mut builder = StdCommand::new(&claude.binary);
    builder
        .args(&command_args)
        .env_remove("CLAUDECODE")
        .env_remove("CLAUDE_CODE_ENTRYPOINT")
        .envs(&claude.env)
        .stdin(Stdio::null())
        .stdout(Stdio::piped())
        .stderr(Stdio::piped());
    if let Some(dir) = &claude.working_dir {
        builder.current_dir(dir);
    }

    let mut child = builder.spawn().map_err(|e| Error::Io {
        message: format!("failed to spawn claude: {e}"),
        source: e,
        working_dir: claude.working_dir.clone(),
    })?;

    let stdout = child.stdout.take().expect("stdout was piped");
    let stderr = child.stderr.take().expect("stderr was piped");

    let (tx, rx) = mpsc::channel::<StreamEvent>();
    let reader_thread = spawn_event_reader(stdout, tx, claude.working_dir.clone());
    let stderr_thread = spawn_stderr_collector(stderr);

    // Dispatch on the caller's thread until the reader hangs up
    // (`false`) or the configured deadline passes (`true`).
    let deadline = claude.timeout.map(|limit| Instant::now() + limit);
    let timed_out = loop {
        let received = match deadline {
            None => rx.recv().map_err(|_| RecvTimeoutError::Disconnected),
            Some(at) => {
                let now = Instant::now();
                if now >= at {
                    break true;
                }
                rx.recv_timeout(at - now)
            }
        };
        match received {
            Ok(event) => handler(event),
            Err(RecvTimeoutError::Timeout) => break true,
            Err(RecvTimeoutError::Disconnected) => break false,
        }
    };

    if timed_out {
        let _ = child.kill();
        let _ = child.wait();
        // Both worker threads can block indefinitely if an orphaned
        // grandchild inherited our pipe fds and keeps the write end open
        // (e.g. a `bash` script whose `sleep` subprocess outlives the
        // SIGKILLed shell). Cap the joins so the timeout error still
        // returns promptly; a thread that misses the budget is left
        // detached.
        let budget = Duration::from_millis(200);
        let stderr_str = join_with_budget(stderr_thread, budget).unwrap_or_default();
        let _ = join_with_budget(reader_thread, budget);
        if !stderr_str.is_empty() {
            warn!(stderr = %stderr_str, "stderr from timed-out streaming process");
        }
        return Err(Error::Timeout {
            timeout_seconds: claude.timeout.map(|d| d.as_secs()).unwrap_or_default(),
        });
    }

    // The reader finished on its own; surface any IO error it hit. A
    // panicked reader is treated as a clean finish.
    if let Err(e) = reader_thread.join().unwrap_or(Ok(())) {
        let _ = child.kill();
        let _ = child.wait();
        let _ = stderr_thread.join();
        return Err(e);
    }

    let status = child.wait().map_err(|e| Error::Io {
        message: "failed to wait for claude process".to_string(),
        source: e,
        working_dir: claude.working_dir.clone(),
    })?;
    let stderr_str = stderr_thread.join().unwrap_or_default();
    let exit_code = status.code().unwrap_or(-1);

    if !status.success() {
        return Err(Error::CommandFailed {
            command: format!("{} {}", claude.binary.display(), command_args.join(" ")),
            exit_code,
            stdout: String::new(),
            stderr: stderr_str,
            working_dir: claude.working_dir.clone(),
        });
    }

    Ok(CommandOutput {
        stdout: String::new(),
        stderr: stderr_str,
        exit_code,
        success: true,
    })
}

/// Spawn the stdout worker for [`stream_query_sync`].
///
/// Parses each non-blank line as a [`StreamEvent`] and sends it over
/// `events`; unparseable lines are logged at debug level and skipped. The
/// thread stops early with `Ok(())` once the receiver is gone, and returns
/// `Error::Io` (tagged with `working_dir`) if a line cannot be read.
#[cfg(all(feature = "sync", feature = "json"))]
fn spawn_event_reader(
    stdout: std::process::ChildStdout,
    events: std::sync::mpsc::Sender<StreamEvent>,
    working_dir: Option<std::path::PathBuf>,
) -> std::thread::JoinHandle<Result<()>> {
    use std::io::BufRead as _;

    std::thread::spawn(move || -> Result<()> {
        for line in std::io::BufReader::new(stdout).lines() {
            let line = line.map_err(|e| Error::Io {
                message: "failed to read stdout line".to_string(),
                source: e,
                working_dir: working_dir.clone(),
            })?;
            if line.trim().is_empty() {
                continue;
            }
            let event = match serde_json::from_str::<StreamEvent>(&line) {
                Ok(event) => event,
                Err(e) => {
                    debug!(line = %line, error = %e, "failed to parse stream event, skipping");
                    continue;
                }
            };
            // A send error means the receiver was dropped: the caller has
            // stopped listening, so there is nothing left to do.
            if events.send(event).is_err() {
                return Ok(());
            }
        }
        Ok(())
    })
}

/// Spawn the stderr worker for [`stream_query_sync`]: read to EOF and
/// return the text (lossy UTF-8). Read errors are ignored.
#[cfg(all(feature = "sync", feature = "json"))]
fn spawn_stderr_collector(
    mut stderr: std::process::ChildStderr,
) -> std::thread::JoinHandle<String> {
    use std::io::Read as _;

    std::thread::spawn(move || {
        let mut raw = Vec::new();
        let _ = stderr.read_to_end(&mut raw);
        String::from_utf8_lossy(&raw).into_owned()
    })
}
672
/// Wait at most `budget` for a worker thread to finish and return its value.
///
/// Returns `None` if the budget runs out first or if the worker panicked.
/// `JoinHandle::join` has no timeout, so the join runs on a helper thread
/// that forwards the value over a channel. When the budget is missed, the
/// helper stays detached and the value is dropped once it eventually
/// arrives.
#[cfg(all(feature = "sync", feature = "json"))]
fn join_with_budget<T: Send + 'static>(
    handle: std::thread::JoinHandle<T>,
    budget: Duration,
) -> Option<T> {
    let (done_tx, done_rx) = std::sync::mpsc::channel::<T>();
    std::thread::spawn(move || {
        // A panicked worker sends nothing; `done_tx` is then dropped and the
        // receiver sees a disconnect, which maps to `None` below.
        if let Ok(value) = handle.join() {
            let _ = done_tx.send(value);
        }
    });
    done_rx.recv_timeout(budget).ok()
}
693
#[cfg(all(test, feature = "json"))]
mod tests {
    // Unit tests for `StreamEvent`'s accessors and the partial-message
    // decoder. The process-spawning functions are not exercised here.
    use super::*;
    use serde_json::json;

    /// Deserialize a JSON value into a `StreamEvent`, panicking on failure.
    fn parse(v: serde_json::Value) -> StreamEvent {
        serde_json::from_value(v).expect("valid StreamEvent")
    }

    /// Wrap a raw streaming event in the CLI's `stream_event` envelope.
    fn wrap(inner: serde_json::Value) -> StreamEvent {
        parse(json!({
            "type": "stream_event",
            "event": inner,
            "session_id": "sess-1",
            "parent_tool_use_id": null,
            "uuid": "11111111-1111-1111-1111-111111111111"
        }))
    }

    #[test]
    fn accessors_read_top_level_fields() {
        let event = parse(json!({
            "type": "result",
            "role": "assistant",
            "result": "done",
            "session_id": "sess-9"
        }));
        assert_eq!(event.event_type(), Some("result"));
        assert_eq!(event.role(), Some("assistant"));
        assert!(event.is_result());
        assert_eq!(event.result_text(), Some("done"));
        assert_eq!(event.session_id(), Some("sess-9"));
    }

    #[test]
    fn accessors_return_none_for_missing_or_mistyped_fields() {
        let event = parse(json!({ "type": 7, "session_id": null }));
        assert_eq!(event.event_type(), None);
        assert_eq!(event.role(), None);
        assert!(!event.is_result());
        assert_eq!(event.result_text(), None);
        assert_eq!(event.session_id(), None);
        assert_eq!(event.cost_usd(), None);
    }

    #[test]
    fn cost_usd_prefers_total_cost_and_falls_back_to_legacy_key() {
        let both = parse(json!({
            "type": "result",
            "total_cost_usd": 0.25,
            "cost_usd": 0.5
        }));
        assert_eq!(both.cost_usd(), Some(0.25));

        let legacy_only = parse(json!({ "type": "result", "cost_usd": 0.5 }));
        assert_eq!(legacy_only.cost_usd(), Some(0.5));
    }

    #[test]
    fn partial_message_text_block_lifecycle() {
        let start = wrap(json!({
            "type": "content_block_start",
            "index": 0,
            "content_block": { "type": "text", "text": "" }
        }));
        assert_eq!(
            start.partial_message(),
            Some(PartialMessageEvent::BlockStart {
                index: 0,
                block_type: BlockType::Text,
            })
        );

        let delta = wrap(json!({
            "type": "content_block_delta",
            "index": 0,
            "delta": { "type": "text_delta", "text": "Hello" }
        }));
        assert_eq!(
            delta.partial_message(),
            Some(PartialMessageEvent::BlockDelta {
                index: 0,
                delta: BlockDelta::Text("Hello".into()),
            })
        );

        let stop = wrap(json!({ "type": "content_block_stop", "index": 0 }));
        assert_eq!(
            stop.partial_message(),
            Some(PartialMessageEvent::BlockStop { index: 0 })
        );
    }

    #[test]
    fn partial_message_thinking_block_lifecycle() {
        let start = wrap(json!({
            "type": "content_block_start",
            "index": 1,
            "content_block": { "type": "thinking", "thinking": "", "signature": "" }
        }));
        assert_eq!(
            start.partial_message(),
            Some(PartialMessageEvent::BlockStart {
                index: 1,
                block_type: BlockType::Thinking,
            })
        );

        let delta = wrap(json!({
            "type": "content_block_delta",
            "index": 1,
            "delta": { "type": "thinking_delta", "thinking": "weighing options" }
        }));
        assert_eq!(
            delta.partial_message(),
            Some(PartialMessageEvent::BlockDelta {
                index: 1,
                delta: BlockDelta::Thinking("weighing options".into()),
            })
        );

        let stop = wrap(json!({ "type": "content_block_stop", "index": 1 }));
        assert_eq!(
            stop.partial_message(),
            Some(PartialMessageEvent::BlockStop { index: 1 })
        );
    }

    #[test]
    fn partial_message_tool_use_block_carries_id_and_name() {
        let start = wrap(json!({
            "type": "content_block_start",
            "index": 2,
            "content_block": {
                "type": "tool_use",
                "id": "toolu_abc",
                "name": "Bash",
                "input": {}
            }
        }));
        assert_eq!(
            start.partial_message(),
            Some(PartialMessageEvent::BlockStart {
                index: 2,
                block_type: BlockType::ToolUse {
                    id: "toolu_abc".into(),
                    name: "Bash".into(),
                },
            })
        );

        let delta = wrap(json!({
            "type": "content_block_delta",
            "index": 2,
            "delta": { "type": "input_json_delta", "partial_json": "{\"cmd\":" }
        }));
        assert_eq!(
            delta.partial_message(),
            Some(PartialMessageEvent::BlockDelta {
                index: 2,
                delta: BlockDelta::InputJson("{\"cmd\":".into()),
            })
        );
    }

    #[test]
    fn partial_message_unknown_kinds_fall_through_to_other() {
        let unknown_block = wrap(json!({
            "type": "content_block_start",
            "index": 3,
            "content_block": { "type": "redacted_thinking", "data": "..." }
        }));
        assert_eq!(
            unknown_block.partial_message(),
            Some(PartialMessageEvent::BlockStart {
                index: 3,
                block_type: BlockType::Other("redacted_thinking".into()),
            })
        );

        let unknown_delta = wrap(json!({
            "type": "content_block_delta",
            "index": 3,
            "delta": { "type": "signature_delta", "signature": "sig" }
        }));
        assert_eq!(
            unknown_delta.partial_message(),
            Some(PartialMessageEvent::BlockDelta {
                index: 3,
                delta: BlockDelta::Other,
            })
        );
    }

    #[test]
    fn partial_message_tolerates_malformed_block_and_delta_payloads() {
        // tool_use without id/name: both default to empty strings.
        let bare_tool = wrap(json!({
            "type": "content_block_start",
            "index": 0,
            "content_block": { "type": "tool_use" }
        }));
        assert_eq!(
            bare_tool.partial_message(),
            Some(PartialMessageEvent::BlockStart {
                index: 0,
                block_type: BlockType::ToolUse {
                    id: String::new(),
                    name: String::new(),
                },
            })
        );

        // A block without a type is reported as Other("").
        let untyped_block = wrap(json!({
            "type": "content_block_start",
            "index": 1,
            "content_block": {}
        }));
        assert_eq!(
            untyped_block.partial_message(),
            Some(PartialMessageEvent::BlockStart {
                index: 1,
                block_type: BlockType::Other(String::new()),
            })
        );

        // A known delta kind whose payload field is missing collapses to Other.
        let textless_delta = wrap(json!({
            "type": "content_block_delta",
            "index": 2,
            "delta": { "type": "text_delta" }
        }));
        assert_eq!(
            textless_delta.partial_message(),
            Some(PartialMessageEvent::BlockDelta {
                index: 2,
                delta: BlockDelta::Other,
            })
        );
    }

    #[test]
    fn partial_message_requires_index_and_payload() {
        let missing_index = wrap(json!({ "type": "content_block_stop" }));
        assert!(missing_index.partial_message().is_none());

        let missing_delta = wrap(json!({ "type": "content_block_delta", "index": 0 }));
        assert!(missing_delta.partial_message().is_none());

        let missing_block = wrap(json!({ "type": "content_block_start", "index": 0 }));
        assert!(missing_block.partial_message().is_none());

        let too_big: u64 = u64::from(u32::MAX) + 1;
        let oversized_index = wrap(json!({ "type": "content_block_stop", "index": too_big }));
        assert!(oversized_index.partial_message().is_none());

        let empty_envelope = parse(json!({ "type": "stream_event", "session_id": "s" }));
        assert!(empty_envelope.partial_message().is_none());
    }

    #[test]
    fn partial_message_returns_none_for_non_partial_events() {
        let result = parse(json!({
            "type": "result",
            "result": "done",
            "session_id": "sess-1",
            "total_cost_usd": 0.01
        }));
        assert!(result.partial_message().is_none());

        let assistant = parse(json!({
            "type": "assistant",
            "message": { "role": "assistant", "content": [] },
            "session_id": "sess-1"
        }));
        assert!(assistant.partial_message().is_none());

        let message_start = wrap(json!({
            "type": "message_start",
            "message": { "id": "msg_1", "role": "assistant", "content": [] }
        }));
        assert!(message_start.partial_message().is_none());
    }

    #[test]
    fn partial_message_accepts_unwrapped_event() {
        let raw = parse(json!({
            "type": "content_block_delta",
            "index": 0,
            "delta": { "type": "text_delta", "text": "hi" }
        }));
        assert_eq!(
            raw.partial_message(),
            Some(PartialMessageEvent::BlockDelta {
                index: 0,
                delta: BlockDelta::Text("hi".into()),
            })
        );
    }
}