use std::fmt;
use serde::{Deserialize, Serialize};
pub use crate::model::{ContentPart, ContentPartKind, Message, Role};
/// A tool (function) call requested by the model.
///
/// `arguments` is kept as the raw string received (e.g. `{"city":"Seattle"}`);
/// it is not parsed or validated here.
#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub struct ToolCall {
    /// Identifier of the call (e.g. `call_abc`).
    pub id: String,
    /// Name of the tool to invoke.
    pub name: String,
    /// Raw, unparsed argument payload.
    pub arguments: String,
}
/// A named marker paired with a `nonce` string.
///
/// NOTE(review): how `nonce` is generated and consumed is not visible in this
/// file — confirm with the code that creates markers.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ThreadMarker {
    /// Marker name.
    pub name: String,
    /// Nonce value carried alongside the name.
    pub nonce: String,
}
/// A named stream of JSON chunks that records each chunk it yields in `items`
/// and can emit the recorded chunks as a trace span via `flush`.
pub struct PromptyStream {
    /// Trace name; also used to build the span's `signature`.
    pub name: String,
    /// Chunks recorded so far, in the order they were pushed or yielded.
    pub items: Vec<serde_json::Value>,
    // Upstream chunk source. `None` when built with `new`, or once the
    // upstream has been drained/taken.
    inner: Option<std::pin::Pin<Box<dyn futures::Stream<Item = serde_json::Value> + Send>>>,
    // Set once the upstream has been fully consumed; `Drop` skips flushing
    // when this is true.
    exhausted: bool,
}
impl PromptyStream {
    /// Creates an empty stream with no upstream source.
    ///
    /// Chunks can be added manually with [`PromptyStream::push`].
    pub fn new(name: impl Into<String>) -> Self {
        Self {
            name: name.into(),
            items: Vec::new(),
            inner: None,
            exhausted: false,
        }
    }

    /// Wraps an upstream chunk stream.
    ///
    /// Every chunk yielded through this `PromptyStream` is also recorded in
    /// `items`.
    pub fn from_stream(
        name: impl Into<String>,
        inner: impl futures::Stream<Item = serde_json::Value> + Send + 'static,
    ) -> Self {
        Self {
            name: name.into(),
            items: Vec::new(),
            inner: Some(Box::pin(inner)),
            exhausted: false,
        }
    }

    /// Appends a chunk to the recorded items.
    pub fn push(&mut self, chunk: serde_json::Value) {
        self.items.push(chunk);
    }

    /// Number of chunks recorded so far.
    pub fn len(&self) -> usize {
        self.items.len()
    }

    /// Returns `true` when no chunks have been recorded.
    pub fn is_empty(&self) -> bool {
        self.items.is_empty()
    }

    /// Returns `true` while an upstream source is attached and not yet drained.
    pub fn is_streaming(&self) -> bool {
        self.inner.is_some()
    }

    /// Emits the recorded chunks as one trace span.
    ///
    /// Does nothing when no chunks have been recorded. Otherwise it does the
    /// following:
    ///
    /// - opens a span named `name`;
    /// - emits `signature`, `inputs` and `result` (the full list of chunks);
    /// - ends the span.
    pub fn flush(&self) {
        if self.items.is_empty() {
            return;
        }
        let span = crate::tracing::Tracer::start(&self.name);
        span.emit(
            "signature",
            &serde_json::Value::String(format!("{}.PromptyStream", self.name)),
        );
        span.emit("inputs", &serde_json::Value::String("None".into()));
        span.emit("result", &serde_json::Value::Array(self.items.clone()));
        span.end();
    }

    /// Concatenates the `choices[0].delta.content` strings of all recorded chunks.
    ///
    /// Chunks without a string at that path are skipped.
    pub fn collect_text(&self) -> String {
        self.items
            .iter()
            .filter_map(|item| item.pointer("/choices/0/delta/content"))
            .filter_map(serde_json::Value::as_str)
            .collect()
    }

    /// Reassembles tool calls from `choices[0].delta.tool_calls` fragments.
    ///
    /// Fragments are grouped by their `index` field (missing index counts as 0):
    ///
    /// - a non-empty `id` or `function.name` overwrites the stored value;
    /// - `function.arguments` fragments are appended in arrival order.
    ///
    /// Results are returned sorted by index.
    pub fn collect_tool_calls(&self) -> Vec<crate::types::ToolCall> {
        use std::collections::BTreeMap;

        // BTreeMap keeps the output ordered by fragment index.
        let mut calls: BTreeMap<usize, (String, String, String)> = BTreeMap::new();

        let fragments = self
            .items
            .iter()
            .filter_map(|item| item.pointer("/choices/0/delta/tool_calls"))
            .filter_map(serde_json::Value::as_array)
            .flatten();

        for tc in fragments {
            let idx = tc.get("index").and_then(|v| v.as_u64()).unwrap_or(0) as usize;
            let id = tc.get("id").and_then(|v| v.as_str()).unwrap_or("");
            let name = tc
                .pointer("/function/name")
                .and_then(|v| v.as_str())
                .unwrap_or("");
            let args = tc
                .pointer("/function/arguments")
                .and_then(|v| v.as_str())
                .unwrap_or("");

            let (existing_id, existing_name, existing_args) = calls.entry(idx).or_default();
            // Only the first fragment of a call normally carries id/name;
            // later fragments leave them empty and must not erase them.
            if !id.is_empty() {
                *existing_id = id.to_string();
            }
            if !name.is_empty() {
                *existing_name = name.to_string();
            }
            existing_args.push_str(args);
        }

        calls
            .into_values()
            .map(|(id, name, arguments)| crate::types::ToolCall {
                id,
                name,
                arguments,
            })
            .collect()
    }

    /// Drains the upstream source, recording every chunk, and returns all recorded items.
    ///
    /// When an upstream was attached and has been fully drained:
    ///
    /// - the upstream is released;
    /// - the stream is marked exhausted;
    /// - the recorded chunks are flushed to the tracer, as `poll_next` does at
    ///   end of stream.
    ///
    /// Setting `exhausted` suppresses the flush in `Drop`, so without the flush
    /// here the trace would be lost.
    ///
    /// The upstream stays in `self` while draining. If this future is dropped
    /// part-way, the remaining chunks can still be consumed later.
    pub async fn collect_all(&mut self) -> &[serde_json::Value] {
        use futures::StreamExt;

        if let Some(inner) = self.inner.as_mut() {
            while let Some(chunk) = inner.next().await {
                self.items.push(chunk);
            }
            self.inner = None;
            self.exhausted = true;
            self.flush();
        }
        &self.items
    }
}
impl futures::Stream for PromptyStream {
    type Item = serde_json::Value;

    /// Polls the attached upstream and records every chunk it yields.
    ///
    /// When the upstream reports its end, the following happens before
    /// returning `Ready(None)`:
    ///
    /// - the stream is marked exhausted;
    /// - the recorded chunks are flushed;
    /// - the upstream is released.
    ///
    /// Without an upstream the stream is immediately finished.
    fn poll_next(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Option<Self::Item>> {
        use std::task::Poll;

        let polled = match self.inner.as_mut() {
            Some(upstream) => upstream.as_mut().poll_next(cx),
            None => return Poll::Ready(None),
        };

        match polled {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Some(value)) => {
                // Keep a copy for tracing; hand the original to the consumer.
                self.items.push(value.clone());
                Poll::Ready(Some(value))
            }
            Poll::Ready(None) => {
                self.exhausted = true;
                self.flush();
                self.inner = None;
                Poll::Ready(None)
            }
        }
    }
}
impl Drop for PromptyStream {
    /// Flushes recorded chunks for any stream not marked exhausted.
    ///
    /// Once `exhausted` is set, the end-of-stream path in `poll_next` has
    /// already flushed, so it is skipped here.
    fn drop(&mut self) {
        let needs_flush = !(self.exhausted || self.items.is_empty());
        if needs_flush {
            self.flush();
        }
    }
}
/// A typed item of a chunked model response, as consumed by
/// `consume_stream_chunks`.
#[derive(Debug, Clone)]
pub enum StreamChunk {
    /// A fragment of response text.
    Text(String),
    /// "Thinking" output; skipped by `consume_stream_chunks`.
    Thinking(String),
    /// A tool call. NOTE(review): presumably already fully assembled by the
    /// producer — confirm.
    Tool(ToolCall),
    /// An error message; `consume_stream_chunks` stops at the first one.
    Error(String),
}
pub async fn consume_stream_chunks(
stream: impl futures::Stream<Item = StreamChunk> + Unpin,
on_token: Option<&dyn Fn(&str)>,
) -> (Vec<ToolCall>, String) {
use futures::StreamExt;
let mut tool_calls = Vec::new();
let mut text_parts = Vec::new();
futures::pin_mut!(stream);
while let Some(chunk) = stream.next().await {
match chunk {
StreamChunk::Text(t) => {
if let Some(cb) = on_token {
cb(&t);
}
text_parts.push(t);
}
StreamChunk::Thinking(_) => {
}
StreamChunk::Tool(tc) => {
tool_calls.push(tc);
}
StreamChunk::Error(_) => {
break;
}
}
}
(tool_calls, text_parts.join(""))
}
impl fmt::Debug for PromptyStream {
    /// Summarises the stream instead of dumping chunk contents.
    ///
    /// The upstream is a `dyn Stream` trait object without a `Debug` bound, so
    /// only its presence is reported.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let streaming = self.inner.is_some();
        let count = self.items.len();
        let mut summary = f.debug_struct("PromptyStream");
        summary.field("name", &self.name);
        summary.field("items_len", &count);
        summary.field("is_streaming", &streaming);
        summary.field("exhausted", &self.exhausted);
        summary.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_message_text() {
        let msg = Message::with_text(Role::User, "Hello");
        assert_eq!(msg.role, Role::User);
        assert_eq!(msg.text_content(), "Hello");
        assert!(!msg.has_rich_content());
    }

    #[test]
    fn test_message_multipart_text() {
        let msg = Message {
            role: Role::Assistant,
            parts: vec![ContentPart::text("Hello "), ContentPart::text("world")],
            ..Default::default()
        };
        assert_eq!(msg.text_content(), "Hello world");
    }

    #[test]
    fn test_message_rich_content() {
        let msg = Message {
            role: Role::User,
            parts: vec![
                ContentPart::text("Look at this:"),
                ContentPart::image("https://example.com/img.png", Some("high".into()), None),
            ],
            ..Default::default()
        };
        assert!(msg.has_rich_content());
    }

    #[test]
    fn test_to_text_content_single() {
        use crate::model::MessageHelpers;
        let msg = Message::with_text(Role::User, "simple");
        assert_eq!(
            msg.to_text_content(),
            serde_json::Value::String("simple".into())
        );
    }

    #[test]
    fn test_to_text_content_multipart() {
        use crate::model::MessageHelpers;
        let msg = Message {
            role: Role::User,
            parts: vec![
                ContentPart::text("Hello"),
                ContentPart::image("data:image/png;base64,abc", None, None),
            ],
            ..Default::default()
        };
        let content = msg.to_text_content();
        assert!(content.is_array());
    }

    #[test]
    fn test_tool_result_message() {
        let msg = Message::tool_result("call_123", r#"{"temp": 72}"#);
        assert_eq!(msg.role, Role::Tool);
        assert_eq!(msg.text_content(), r#"{"temp": 72}"#);
        assert_eq!(
            msg.metadata.get("tool_call_id").and_then(|v| v.as_str()),
            Some("call_123")
        );
    }

    #[test]
    fn test_role_display() {
        assert_eq!(Role::System.to_string(), "system");
        assert_eq!(Role::Assistant.to_string(), "assistant");
    }

    #[test]
    fn test_role_from_str() {
        assert_eq!(Role::from_str_ignore_case("System"), Some(Role::System));
        assert_eq!(Role::from_str_ignore_case("USER"), Some(Role::User));
        assert_eq!(Role::from_str_opt("unknown"), None);
    }

    #[test]
    fn test_message_serde_roundtrip() {
        let msg = Message::with_text(Role::User, "Hello");
        let json = serde_json::to_string(&msg).unwrap();
        let deserialized: Message = serde_json::from_str(&json).unwrap();
        assert_eq!(msg, deserialized);
    }

    #[test]
    fn test_tool_call_serde() {
        let tc = ToolCall {
            id: "call_abc".into(),
            name: "get_weather".into(),
            arguments: r#"{"city":"Seattle"}"#.into(),
        };
        let json = serde_json::to_value(&tc).unwrap();
        assert_eq!(json["id"], "call_abc");
        assert_eq!(json["name"], "get_weather");
    }

    #[test]
    fn test_prompty_stream_new() {
        let stream = PromptyStream::new("test-stream");
        assert_eq!(stream.name, "test-stream");
        assert!(stream.items.is_empty());
    }

    #[test]
    fn test_prompty_stream_new_is_not_streaming() {
        let stream = PromptyStream::new("idle");
        assert!(!stream.is_streaming());
        assert!(stream.is_empty());
        assert_eq!(stream.len(), 0);
    }

    #[test]
    fn test_prompty_stream_push() {
        let mut stream = PromptyStream::new("test");
        stream.push(serde_json::json!({"chunk": 1}));
        stream.push(serde_json::json!({"chunk": 2}));
        assert_eq!(stream.items.len(), 2);
        assert_eq!(stream.items[0]["chunk"], 1);
        assert_eq!(stream.items[1]["chunk"], 2);
    }

    #[test]
    fn test_prompty_stream_debug() {
        let mut stream = PromptyStream::new("debug-test");
        stream.push(serde_json::json!("chunk"));
        let dbg = format!("{:?}", stream);
        assert!(dbg.contains("debug-test"));
        assert!(dbg.contains("items_len: 1"));
    }

    #[test]
    fn test_collect_text_concatenates_deltas() {
        let mut stream = PromptyStream::new("text");
        stream.push(serde_json::json!({"choices": [{"delta": {"content": "Hel"}}]}));
        // A chunk without content must be skipped, not break concatenation.
        stream.push(serde_json::json!({"choices": [{"delta": {}}]}));
        stream.push(serde_json::json!({"choices": [{"delta": {"content": "lo"}}]}));
        assert_eq!(stream.collect_text(), "Hello");
    }

    #[test]
    fn test_collect_tool_calls_merges_fragments_in_index_order() {
        let mut stream = PromptyStream::new("tools");
        stream.push(serde_json::json!({"choices": [{"delta": {"tool_calls": [
            {"index": 1, "id": "call_b", "function": {"name": "second", "arguments": ""}}
        ]}}]}));
        stream.push(serde_json::json!({"choices": [{"delta": {"tool_calls": [
            {"index": 0, "id": "call_a", "function": {"name": "first", "arguments": "{\"x\":"}}
        ]}}]}));
        stream.push(serde_json::json!({"choices": [{"delta": {"tool_calls": [
            {"index": 0, "function": {"arguments": "1}"}}
        ]}}]}));
        stream.push(serde_json::json!({"choices": [{"delta": {"tool_calls": [
            {"index": 1, "function": {"arguments": "{}"}}
        ]}}]}));

        let calls = stream.collect_tool_calls();
        assert_eq!(calls.len(), 2);
        assert_eq!(calls[0].id, "call_a");
        assert_eq!(calls[0].name, "first");
        assert_eq!(calls[0].arguments, r#"{"x":1}"#);
        assert_eq!(calls[1].id, "call_b");
        assert_eq!(calls[1].name, "second");
        assert_eq!(calls[1].arguments, "{}");
    }

    #[test]
    fn test_message_empty_text_content() {
        let msg = Message {
            role: Role::User,
            parts: vec![],
            ..Default::default()
        };
        assert_eq!(msg.text_content(), "");
        assert!(!msg.has_rich_content());
    }

    #[test]
    fn test_tool_result_metadata_fields() {
        let msg = Message::tool_result("call_99", "result text");
        assert_eq!(msg.metadata["tool_call_id"], "call_99");
        assert_eq!(msg.role, Role::Tool);
    }
}