use std::io::{BufRead, BufReader};
use crate::{ChatResponse, Error, Result};
/// Asynchronous stream of chat responses delivered over a tokio mpsc channel.
///
/// Each item is a `Result<ChatResponse>`: either a decoded response or an
/// error produced by whoever holds the sending half. The receiver is kept
/// behind a [`tokio::sync::Mutex`] so that [`ChatStream::next`] can be called
/// through a shared reference (`&self`).
pub struct ChatStream {
    rx: tokio::sync::Mutex<tokio::sync::mpsc::Receiver<Result<ChatResponse>>>,
}

impl ChatStream {
    /// Wraps the receiving half of the channel that feeds this stream.
    pub(crate) fn new(rx: tokio::sync::mpsc::Receiver<Result<ChatResponse>>) -> Self {
        Self {
            rx: tokio::sync::Mutex::new(rx),
        }
    }

    /// Waits for the next item from the stream.
    ///
    /// Returns `None` once every sender has been dropped and all buffered
    /// items have been received. Concurrent callers are serialized: the
    /// internal mutex is held across the `recv().await`.
    pub async fn next(&self) -> Option<Result<ChatResponse>> {
        self.rx.lock().await.recv().await
    }

    /// Consumes the stream and gathers every remaining item into a `Vec`,
    /// stopping when the channel is closed and drained.
    ///
    /// # Errors
    ///
    /// Returns the first `Err` item received. Responses gathered before it
    /// are discarded, and the receiver is dropped.
    pub async fn collect(self) -> Result<Vec<ChatResponse>> {
        // `self` is owned here, so the receiver can be taken out of the mutex
        // directly instead of locking it.
        let mut rx = self.rx.into_inner();
        let mut out = Vec::new();
        while let Some(item) = rx.recv().await {
            // `?` forwards the error unchanged: both sides use the crate's
            // error type, so the `From` conversion is the identity.
            out.push(item?);
        }
        Ok(out)
    }
}
/// Blocking, line-oriented reader over a streaming HTTP response body.
///
/// The body is wrapped in a [`BufReader`] and split into lines. The lines are
/// consumed by this type's [`Iterator`] implementation.
pub struct ChatStreamBlocking {
    lines: std::io::Lines<BufReader<reqwest::blocking::Response>>,
}

impl ChatStreamBlocking {
    /// Builds a stream that reads `response` one line at a time.
    pub(crate) fn new(response: reqwest::blocking::Response) -> Self {
        let reader = BufReader::new(response);
        let lines = reader.lines();
        Self { lines }
    }
}
impl Iterator for ChatStreamBlocking {
type Item = Result<ChatResponse>;
fn next(&mut self) -> Option<Self::Item> {
loop {
match self.lines.next() {
None => return None,
Some(Err(e)) => return Some(Err(Error::StreamError(e.to_string()))),
Some(Ok(line)) => {
let trimmed = line.trim();
if trimmed.is_empty() {
continue;
}
return Some(
serde_json::from_str::<ChatResponse>(trimmed)
.map_err(|e| Error::StreamError(e.to_string())),
);
}
}
}
}
}