ggen-core 26.5.19

Core graph-aware code generation engine
Documentation
use crate::transport::error::{Result, TransportError};
use crate::transport::session::{ResumeCursor, SessionId};

use futures::stream::Stream;
use serde::{Deserialize, Serialize};
use std::pin::Pin;
use std::task::{Context, Poll};
use tokio::sync::mpsc;

/// A single message carried over a session stream.
///
/// The type is serializable via `serde` and cheap to inspect: every field is public.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamMessage {
    /// Sequence number assigned to this message.
    pub sequence: u64,
    /// Session the message belongs to.
    pub session_id: SessionId,
    /// Raw message bytes.
    pub payload: Vec<u8>,
    /// `true` when the message has been marked as final (see `mark_final`).
    pub is_final: bool,
    /// Optional JSON metadata attached to the message.
    pub metadata: Option<serde_json::Value>,
}

impl StreamMessage {
    /// Creates a non-final message without metadata.
    ///
    /// `sequence`, `session_id` and `payload` are stored as given.
    pub fn new(sequence: u64, session_id: SessionId, payload: Vec<u8>) -> Self {
        let metadata = None;
        let is_final = false;
        Self {
            sequence,
            session_id,
            payload,
            is_final,
            metadata,
        }
    }

    /// Returns the message with `metadata` attached, replacing any previous metadata.
    pub fn with_metadata(self, metadata: serde_json::Value) -> Self {
        Self {
            metadata: Some(metadata),
            ..self
        }
    }

    /// Returns the message with its `is_final` flag set.
    pub fn mark_final(self) -> Self {
        Self {
            is_final: true,
            ..self
        }
    }
}

/// Receiving half of a session stream.
///
/// Wraps a bounded tokio channel receiver and remembers a position that can be
/// turned into a [`ResumeCursor`] via `get_cursor`.
pub struct MessageStream {
    /// Channel end that yields messages or transport errors.
    receiver: mpsc::Receiver<Result<StreamMessage>>,
    /// Session this stream was created for.
    session_id: SessionId,
    /// Sequence number recorded from the most recently received `Ok` message;
    /// 0 until one is recorded.
    current_position: u64,
}

impl MessageStream {
    pub fn new(receiver: mpsc::Receiver<Result<StreamMessage>>, session_id: SessionId) -> Self {
        Self {
            receiver,
            session_id,
            current_position: 0,
        }
    }

    pub fn get_cursor(&self) -> ResumeCursor {
        ResumeCursor::new(self.session_id.clone(), self.current_position)
    }

    pub async fn next(&mut self) -> Option<Result<StreamMessage>> {
        match self.receiver.recv().await {
            Some(Ok(msg)) => {
                self.current_position = msg.sequence;
                Some(Ok(msg))
            }
            Some(Err(e)) => Some(Err(e)),
            None => None,
        }
    }
}

impl Stream for MessageStream {
    type Item = Result<StreamMessage>;

    /// Polls the underlying channel for the next item.
    ///
    /// Like the inherent `next` method, this records the sequence number of every
    /// `Ok` message it yields, so `get_cursor` stays accurate when the stream is
    /// driven through the `Stream` trait (e.g. by stream combinators) rather than
    /// by calling `next` directly. Errors and pending/closed states leave the
    /// recorded position unchanged.
    fn poll_next(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let polled = self.receiver.poll_recv(cx);
        if let Poll::Ready(Some(Ok(message))) = &polled {
            // Keep the resume position in sync with what the consumer has seen.
            self.current_position = message.sequence;
        }
        polled
    }
}

/// Builder for a connected [`StreamSender`] / [`MessageStream`] pair.
pub struct StreamBuilder {
    /// Session both halves will be bound to.
    session_id: SessionId,
    /// Capacity passed to the bounded channel created in `build`.
    buffer_size: usize,
    /// Sequence number the sender starts from; `None` means start at 0.
    resume_from: Option<u64>,
}

impl StreamBuilder {
    /// Channel capacity used when `with_buffer_size` is not called.
    const DEFAULT_BUFFER_SIZE: usize = 100;

    /// Starts a builder for `session_id` with the default buffer size and no
    /// resume position.
    pub fn new(session_id: SessionId) -> Self {
        Self {
            session_id,
            buffer_size: Self::DEFAULT_BUFFER_SIZE,
            resume_from: None,
        }
    }

    /// Sets the capacity of the channel created by `build`.
    ///
    /// A size of 0 is accepted here and treated as 1 when the channel is created.
    pub fn with_buffer_size(mut self, size: usize) -> Self {
        self.buffer_size = size;
        self
    }

    /// Makes the sender start numbering messages at `position` instead of 0.
    pub fn resume_from(mut self, position: u64) -> Self {
        self.resume_from = Some(position);
        self
    }

    /// Creates the channel and returns its sending and receiving halves, both
    /// bound to the builder's session id.
    pub fn build(self) -> (StreamSender, MessageStream) {
        // tokio's bounded `mpsc::channel` panics when asked for zero capacity,
        // so clamp to the smallest valid value instead of crashing the caller.
        let capacity = self.buffer_size.max(1);
        let (tx, rx) = mpsc::channel(capacity);
        let sender = StreamSender {
            sender: tx,
            session_id: self.session_id.clone(),
            next_sequence: self.resume_from.unwrap_or(0),
        };
        let stream = MessageStream::new(rx, self.session_id);
        (sender, stream)
    }
}

/// Sending half of a session stream.
///
/// Stamps each outgoing message with the session id and a sequence number
/// taken from an internal counter.
pub struct StreamSender {
    /// Channel end used to push messages or transport errors to the stream.
    sender: mpsc::Sender<Result<StreamMessage>>,
    /// Session id copied into every outgoing message.
    session_id: SessionId,
    /// Sequence number that will be given to the next message.
    next_sequence: u64,
}

impl StreamSender {
    pub async fn send(&mut self, payload: Vec<u8>) -> Result<()> {
        let msg = StreamMessage::new(self.next_sequence, self.session_id.clone(), payload);
        self.next_sequence += 1;
        self.sender
            .send(Ok(msg))
            .await
            .map_err(|_| TransportError::StreamError("Channel closed".to_string()))
    }

    pub async fn send_with_metadata(
        &mut self, payload: Vec<u8>, metadata: serde_json::Value,
    ) -> Result<()> {
        let msg = StreamMessage::new(self.next_sequence, self.session_id.clone(), payload)
            .with_metadata(metadata);
        self.next_sequence += 1;
        self.sender
            .send(Ok(msg))
            .await
            .map_err(|_| TransportError::StreamError("Channel closed".to_string()))
    }

    pub async fn send_final(&mut self, payload: Vec<u8>) -> Result<()> {
        let msg =
            StreamMessage::new(self.next_sequence, self.session_id.clone(), payload).mark_final();
        self.next_sequence += 1;
        self.sender
            .send(Ok(msg))
            .await
            .map_err(|_| TransportError::StreamError("Channel closed".to_string()))
    }

    pub async fn send_error(&self, error: TransportError) -> Result<()> {
        self.sender
            .send(Err(error))
            .await
            .map_err(|_| TransportError::StreamError("Channel closed".to_string()))
    }

    pub fn get_current_sequence(&self) -> u64 {
        self.next_sequence
    }
}

/// A control request addressed to one session's stream.
///
/// Serializable via `serde`; how the request is acted upon is not defined in
/// this part of the file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamControl {
    /// Session the request applies to.
    pub session_id: SessionId,
    /// What is being requested.
    pub action: StreamAction,
}

/// The kinds of request a [`StreamControl`] can carry.
///
/// NOTE(review): the handling of each action lives outside this section; the
/// descriptions below follow the variant names and should be confirmed against
/// the consumer of these messages.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum StreamAction {
    /// Pause request.
    Pause,
    /// Resume request, optionally naming the position to resume from.
    Resume { from_position: Option<u64> },
    /// Cancel request.
    Cancel,
    /// Acknowledgement that names a sequence number.
    Acknowledge { sequence: u64 },
}

impl StreamControl {
    pub fn pause(session_id: SessionId) -> Self {
        Self {
            session_id,
            action: StreamAction::Pause,
        }
    }

    pub fn resume(session_id: SessionId, from_position: Option<u64>) -> Self {
        Self {
            session_id,
            action: StreamAction::Resume { from_position },
        }
    }

    pub fn cancel(session_id: SessionId) -> Self {
        Self {
            session_id,
            action: StreamAction::Cancel,
        }
    }

    pub fn acknowledge(session_id: SessionId, sequence: u64) -> Self {
        Self {
            session_id,
            action: StreamAction::Acknowledge { sequence },
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a sender/stream pair for a fresh session using the builder defaults.
    fn default_pair() -> (StreamSender, MessageStream) {
        StreamBuilder::new(SessionId::new()).build()
    }

    // A single message arrives with sequence 0 and its payload intact.
    #[tokio::test]
    async fn test_stream_basic() {
        let (mut tx, mut rx) = default_pair();

        tx.send(b"test".to_vec()).await.unwrap();

        let received = rx.next().await.unwrap().unwrap();
        assert_eq!(received.sequence, 0);
        assert_eq!(received.payload, b"test".to_vec());
    }

    // Consecutive sends are numbered 0, 1, 2 in order.
    #[tokio::test]
    async fn test_stream_sequence() {
        let (mut tx, mut rx) = default_pair();

        for body in ["1", "2", "3"] {
            tx.send(Vec::from(body)).await.unwrap();
        }

        let mut seen = Vec::new();
        for _ in 0..3 {
            seen.push(rx.next().await.unwrap().unwrap().sequence);
        }

        assert_eq!(seen[0], 0);
        assert_eq!(seen[1], 1);
        assert_eq!(seen[2], 2);
    }

    // `send_final` delivers a message flagged as final.
    #[tokio::test]
    async fn test_stream_final() {
        let (mut tx, mut rx) = default_pair();

        tx.send_final(b"final".to_vec()).await.unwrap();

        let received = rx.next().await.unwrap().unwrap();
        assert!(received.is_final);
    }

    // A builder configured with `resume_from(10)` numbers its first message 10.
    #[tokio::test]
    async fn test_stream_resume() {
        let (mut tx, mut rx) = StreamBuilder::new(SessionId::new())
            .resume_from(10)
            .build();

        tx.send(b"test".to_vec()).await.unwrap();

        let received = rx.next().await.unwrap().unwrap();
        assert_eq!(received.sequence, 10);
    }
}