use std::pin::Pin;
use std::task::{Context, Poll};
use futures::Stream;
use pin_project_lite::pin_project;
use serde::Serialize;
use super::span::LangfuseSpan;
/// Boxed, thread-safe callback that turns the JSON-encoded items recorded by
/// an observing adapter into the string used as the span's output.
type TransformFn = Box<dyn Fn(&[String]) -> String + Sync + Send>;
pin_project! {
    /// A [`Stream`] adapter that records every item it yields on a [`LangfuseSpan`].
    ///
    /// Each yielded item is serialized to a JSON string and appended to
    /// `collected`. Items that fail to serialize are still yielded but are not
    /// recorded. When the inner stream finishes, the span's output is set and
    /// the span is ended. The output comes from `transform` if one was
    /// supplied; otherwise it is the JSON array of the collected strings.
    ///
    /// If the adapter is dropped before the inner stream finishes (for example
    /// a cancelled request or a consumer that stops early), the same
    /// finalization runs on drop with whatever was collected so far. This keeps
    /// the span from being left open.
    pub struct ObservingStream<S: Stream> {
        // The wrapped stream; structurally pinned so it can be polled in place.
        #[pin]
        inner: S,
        // `Some` until the span has been finalized; `take()`n exactly once.
        span: Option<LangfuseSpan>,
        // JSON encodings of the items yielded so far, in order.
        collected: Vec<String>,
        // Optional user function that builds the span output from `collected`.
        transform: Option<TransformFn>,
    }

    impl<S: Stream> PinnedDrop for ObservingStream<S> {
        fn drop(this: Pin<&mut Self>) {
            // Finalizing runs user code (`transform`) and span methods; skip it
            // while unwinding so a second panic cannot abort the process.
            if std::thread::panicking() {
                return;
            }
            let this = this.project();
            // `poll_next` takes the span once the stream completes, so this
            // only fires for streams dropped before reaching the end.
            if let Some(span) = this.span.take() {
                if let Some(transform) = this.transform.as_ref() {
                    let output = transform(this.collected);
                    span.set_output(&output);
                } else {
                    let output = serde_json::json!(this.collected);
                    span.set_output(&output);
                }
                span.end();
            }
        }
    }
}
impl<S: Stream> ObservingStream<S> {
#[must_use]
pub fn new(span: LangfuseSpan, inner: S) -> Self {
Self {
inner,
span: Some(span),
collected: Vec::new(),
transform: None,
}
}
#[must_use]
pub fn with_transform(
span: LangfuseSpan,
inner: S,
transform: impl Fn(&[String]) -> String + Send + Sync + 'static,
) -> Self {
Self {
inner,
span: Some(span),
collected: Vec::new(),
transform: Some(Box::new(transform)),
}
}
}
impl<S: Stream> std::fmt::Debug for ObservingStream<S> {
    /// Shows only how many items have been recorded. The inner stream has no
    /// `Debug` bound here and the boxed transform closure has no `Debug` impl,
    /// so the other fields are omitted.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let recorded = self.collected.len();
        formatter
            .debug_struct("ObservingStream")
            .field("collected_count", &recorded)
            .finish()
    }
}
impl<S> Stream for ObservingStream<S>
where
    S: Stream,
    S::Item: Serialize,
{
    type Item = S::Item;

    /// Forwards the next item from the inner stream and records its JSON form.
    ///
    /// At end of stream the span, if not yet finalized, receives its output
    /// and is then ended. The output is the transform's result, or the JSON
    /// array of the recorded strings when no transform was given.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let fields = self.project();
        let next = match fields.inner.poll_next(cx) {
            Poll::Pending => return Poll::Pending,
            Poll::Ready(next) => next,
        };

        if let Some(item) = next {
            // Best-effort: an item that cannot be serialized is still yielded.
            if let Ok(encoded) = serde_json::to_string(&item) {
                fields.collected.push(encoded);
            }
            return Poll::Ready(Some(item));
        }

        // `take()` ensures the span is finalized at most once, even if the
        // stream is polled again after completion.
        if let Some(span) = fields.span.take() {
            match fields.transform.as_ref() {
                Some(transform) => {
                    span.set_output(&transform(fields.collected));
                }
                None => {
                    // NOTE(review): each entry is already a JSON string, so this
                    // output is an array of JSON-encoded strings (items appear
                    // double-encoded). Confirm this is intended.
                    span.set_output(&serde_json::json!(fields.collected));
                }
            }
            span.end();
        }
        Poll::Ready(None)
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // The wrapper yields exactly the inner stream's items, so its hint carries over.
        self.inner.size_hint()
    }
}
/// An [`Iterator`] adapter that records every item it yields on a [`LangfuseSpan`].
///
/// Items are serialized to JSON strings as they pass through. Items that fail
/// to serialize are yielded but not recorded. Once the inner iterator is
/// exhausted, the span's output is set and the span is ended.
pub struct ObservingIterator<I: Iterator> {
    /// The wrapped iterator; its items are forwarded unchanged.
    inner: I,
    /// `Some` until the span has been finalized.
    span: Option<LangfuseSpan>,
    /// JSON encodings of the items yielded so far, in order.
    collected: Vec<String>,
    /// Optional function that builds the span output from `collected`.
    transform: Option<TransformFn>,
}
impl<I: Iterator> ObservingIterator<I> {
#[must_use]
pub fn new(span: LangfuseSpan, inner: I) -> Self {
Self {
inner,
span: Some(span),
collected: Vec::new(),
transform: None,
}
}
#[must_use]
pub fn with_transform(
span: LangfuseSpan,
inner: I,
transform: impl Fn(&[String]) -> String + Send + Sync + 'static,
) -> Self {
Self {
inner,
span: Some(span),
collected: Vec::new(),
transform: Some(Box::new(transform)),
}
}
}
impl<I: Iterator> std::fmt::Debug for ObservingIterator<I> {
    /// Reports only the number of recorded items. The inner iterator carries
    /// no `Debug` bound here and the boxed transform has no `Debug` impl.
    fn fmt(&self, out: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut view = out.debug_struct("ObservingIterator");
        view.field("collected_count", &self.collected.len());
        view.finish()
    }
}
impl<I: Iterator> ObservingIterator<I> {
    /// Sets the span output from the recorded items and ends the span.
    ///
    /// The span is `take()`n out of `self.span`, so every call after the first
    /// is a no-op. The output is the transform's result when a transform was
    /// given, and otherwise the JSON array of the recorded strings.
    fn finish_span(&mut self) {
        if let Some(span) = self.span.take() {
            if let Some(transform) = self.transform.as_ref() {
                let output = transform(&self.collected);
                span.set_output(&output);
            } else {
                // NOTE(review): each entry is already a JSON string, so this
                // output is an array of JSON-encoded strings (items appear
                // double-encoded). Confirm this is intended.
                let output = serde_json::json!(self.collected);
                span.set_output(&output);
            }
            span.end();
        }
    }
}

impl<I> Iterator for ObservingIterator<I>
where
    I: Iterator,
    I::Item: Serialize,
{
    type Item = I::Item;

    /// Forwards the next item from the inner iterator and records its JSON form.
    ///
    /// Items that fail to serialize are still yielded but are not recorded.
    /// When the inner iterator is exhausted, the span is finalized.
    fn next(&mut self) -> Option<Self::Item> {
        match self.inner.next() {
            Some(item) => {
                if let Ok(json) = serde_json::to_string(&item) {
                    self.collected.push(json);
                }
                Some(item)
            }
            None => {
                self.finish_span();
                None
            }
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        // The adapter yields exactly the inner iterator's items.
        self.inner.size_hint()
    }
}

impl<I: Iterator> Drop for ObservingIterator<I> {
    /// Finalizes the span if iteration stopped before the inner iterator was
    /// exhausted (for example `take(n)` or an early `break`), so an abandoned
    /// adapter does not leave its span open. If `next()` already finalized the
    /// span, this does nothing.
    fn drop(&mut self) {
        // Finalizing runs user code (`transform`); skip it while unwinding so a
        // second panic cannot abort the process.
        if !std::thread::panicking() {
            self.finish_span();
        }
    }
}