sipp-rs 0.1.0

Unified Rust library for extensible Sipp inference
use std::sync::{Arc, Mutex};

use crate::client::{
    EndpointRef, SippChatRequest, SippEmbedRequest, SippEmbeddingResponse, SippEmbeddingRun,
    SippQueryRequest, SippResponseMetadata, SippTextResponse, SippTextRun, SippTokenBatches,
};
use crate::core::{FinishReason, TokenBatch, TokenEmissionStats};
use futures_util::{stream, StreamExt};

use super::{
    AdmissionController, AdmissionPermit, Authorizer, GatewayExecutor, GatewayPipeline, Operation,
    TargetResolver,
};
use crate::gateway_core::{
    GatewayCancellationReason, GatewayError, GatewayErrorKind, GatewayRequestContext,
    GatewayResult, GatewayStreamEvent,
};

// A successful unary query must run resolve → authorize → admit → execute in that order,
// and its admission permit must be released once the call returns.
#[tokio::test]
async fn pipeline_orders_policy_before_execution_and_releases_permit() {
    // `events` records stage order; `drops` counts released permits.
    let events = Arc::new(Mutex::new(Vec::new()));
    let drops = Arc::new(Mutex::new(0));
    let gateway = pipeline(Arc::clone(&events), Arc::clone(&drops), false);

    let context = GatewayRequestContext::new(Some("request".to_string()));
    let response = gateway
        .query(&context, "public", SippQueryRequest::default())
        .await
        .expect("query");
    assert_eq!(response.text, "ok");

    let recorded = events.lock().expect("events");
    assert_eq!(*recorded, ["resolve", "authorize", "admit", "execute"]);
    drop(recorded);

    // The call has completed, so the single permit it took must already be released.
    let released = *drops.lock().expect("drops");
    assert_eq!(released, 1);
}

// A denied request must stop right after authorization: admission is never attempted
// (so no permit is taken or released) and the executor is never invoked.
#[tokio::test]
async fn authorization_stops_admission_and_execution() {
    let events = Arc::new(Mutex::new(Vec::new()));
    let drops = Arc::new(Mutex::new(0));
    let pipeline = pipeline(events.clone(), drops.clone(), true);
    let error = pipeline
        .chat(
            &GatewayRequestContext::default(),
            "public",
            SippChatRequest::default(),
        )
        .await
        .expect_err("authorization");

    assert_eq!(error.kind, GatewayErrorKind::Authorization);
    assert_eq!(*events.lock().expect("events"), ["resolve", "authorize"]);
    // No permit was ever acquired, so none can have been dropped.
    assert_eq!(*drops.lock().expect("drops"), 0);
}

// A streaming call must keep its admission permit while events are still being delivered,
// and release it only after the stream has produced its terminal event.
#[tokio::test]
async fn streaming_holds_admission_until_terminal_event() {
    let events = Arc::new(Mutex::new(Vec::new()));
    let drops = Arc::new(Mutex::new(0));
    let pipeline = pipeline(events, drops.clone(), false);
    let mut stream = pipeline
        .stream_query(
            &GatewayRequestContext::default(),
            "public",
            SippQueryRequest::default(),
        )
        .expect("stream");

    assert_eq!(*drops.lock().expect("drops"), 0);
    assert!(matches!(
        stream.next().await,
        Some(Ok(GatewayStreamEvent::TokenBatch(_)))
    ));
    // Only a non-terminal event has been delivered so far; the permit must still be held.
    // Without this check, a permit released on the first poll would go unnoticed.
    assert_eq!(*drops.lock().expect("drops"), 0);
    while stream.next().await.is_some() {}
    // The stream is drained but not dropped, so the release is attributable to the
    // terminal event rather than to dropping the stream.
    assert_eq!(*drops.lock().expect("drops"), 1);
}

// Cancelling the caller's context while the executor's run is still pending must end the
// call with a `Cancelled` gateway error instead of leaving it hanging.
#[tokio::test]
async fn cancellation_reaches_the_active_client_run() {
    // None of the recorders are inspected here; each collaborator just needs its own.
    let recorder = || Arc::new(Mutex::new(Vec::new()));
    let gateway = GatewayPipeline::new(
        Arc::new(Resolver { events: recorder() }),
        Arc::new(Policy {
            events: recorder(),
            deny: false,
        }),
        Arc::new(Admission {
            events: recorder(),
            drops: Arc::new(Mutex::new(0)),
        }),
        Arc::new(PendingExecutor),
    );

    let caller = GatewayRequestContext::default();
    let spawned_context = caller.clone();
    let in_flight = tokio::spawn(async move {
        gateway
            .query(&spawned_context, "public", SippQueryRequest::default())
            .await
    });

    // Yield so the spawned task gets a chance to run up to the never-completing executor.
    tokio::task::yield_now().await;
    caller
        .cancellation
        .cancel(GatewayCancellationReason::CallerCancelled);

    let outcome = in_flight.await.expect("task");
    let error = outcome.expect_err("cancelled");
    assert_eq!(error.kind, GatewayErrorKind::Cancelled);
}

// Builds a test pipeline whose resolver, policy, admission and executor stages all log
// into `events`. Released permits increment `drops`. When `deny` is true, the policy
// rejects every request.
fn pipeline(
    events: Arc<Mutex<Vec<&'static str>>>,
    drops: Arc<Mutex<u32>>,
    deny: bool,
) -> GatewayPipeline {
    let resolver = Resolver {
        events: Arc::clone(&events),
    };
    let policy = Policy {
        events: Arc::clone(&events),
        deny,
    };
    let admission = Admission {
        events: Arc::clone(&events),
        drops,
    };
    let executor = Executor { events };

    GatewayPipeline::new(
        Arc::new(resolver),
        Arc::new(policy),
        Arc::new(admission),
        Arc::new(executor),
    )
}

// Test resolver: logs "resolve" on every call and only knows the target named "public".
struct Resolver {
    events: Arc<Mutex<Vec<&'static str>>>,
}

impl TargetResolver for Resolver {
    fn resolve(
        &self,
        _context: &GatewayRequestContext,
        target: &str,
        _operation: Operation,
    ) -> GatewayResult<EndpointRef> {
        self.events.lock().expect("events").push("resolve");
        match target {
            "public" => Ok(EndpointRef::gateway("resolved")),
            _ => Err(GatewayError::new(
                GatewayErrorKind::Resolution,
                "missing target",
            )),
        }
    }
}

// Test authorizer: logs "authorize" on every call and rejects everything when `deny` is set.
struct Policy {
    events: Arc<Mutex<Vec<&'static str>>>,
    // When true, every authorization check fails with an `Authorization` error.
    deny: bool,
}

impl Authorizer for Policy {
    fn authorize(
        &self,
        _context: &GatewayRequestContext,
        _target: &str,
        _endpoint: &EndpointRef,
        _operation: Operation,
    ) -> GatewayResult<()> {
        self.events.lock().expect("events").push("authorize");
        if self.deny {
            return Err(GatewayError::new(GatewayErrorKind::Authorization, "denied"));
        }
        Ok(())
    }
}

// Test admission controller: logs "admit" and always grants a `Permit` that shares the
// `drops` counter, so tests can observe when the permit is released.
struct Admission {
    events: Arc<Mutex<Vec<&'static str>>>,
    drops: Arc<Mutex<u32>>,
}

impl AdmissionController for Admission {
    fn acquire(
        &self,
        _context: &GatewayRequestContext,
        _target: &str,
        _endpoint: &EndpointRef,
        _operation: Operation,
    ) -> GatewayResult<Box<dyn AdmissionPermit>> {
        self.events.lock().expect("events").push("admit");
        let permit = Permit {
            drops: Arc::clone(&self.drops),
        };
        Ok(Box::new(permit))
    }
}

// Admission permit stub: releasing it (dropping it) increments the shared `drops` counter.
struct Permit {
    drops: Arc<Mutex<u32>>,
}

impl Drop for Permit {
    fn drop(&mut self) {
        let mut released = self.drops.lock().expect("drops");
        *released += 1;
    }
}

// Test executor: records "execute" for every operation it runs and answers immediately.
// Text calls get a single "ok" token batch followed by `text_response()`. Embedding calls
// get a one-element vector.
struct Executor {
    events: Arc<Mutex<Vec<&'static str>>>,
}

impl GatewayExecutor for Executor {
    fn query(&self, _context: &GatewayRequestContext, _request: SippQueryRequest) -> SippTextRun {
        self.events.lock().expect("events").push("execute");
        let batch = TokenBatch {
            request_id: "request".to_string(),
            stream_id: 0,
            sequence_start: 0,
            text: "ok".to_string(),
            frame_count: 1,
            // Byte length of the "ok" payload above.
            byte_count: 2,
            stats: TokenEmissionStats::default(),
        };
        SippTextRun::from_parts(
            SippTokenBatches::from_stream(Box::pin(stream::iter([batch]))),
            Box::pin(async { Ok(text_response()) }),
        )
    }

    // Chat is served by the query path (which records the "execute" event), carrying only
    // the requested endpoint across.
    fn chat(&self, context: &GatewayRequestContext, request: SippChatRequest) -> SippTextRun {
        self.query(
            context,
            SippQueryRequest {
                endpoint: request.endpoint,
                ..SippQueryRequest::default()
            },
        )
    }

    fn embed(
        &self,
        _context: &GatewayRequestContext,
        _request: SippEmbedRequest,
    ) -> SippEmbeddingRun {
        // Record the execution like `query`/`chat` do, so stage-ordering assertions also
        // cover embedding calls.
        self.events.lock().expect("events").push("execute");
        SippEmbeddingRun::from_response(Box::pin(async {
            Ok(SippEmbeddingResponse {
                endpoint: EndpointRef::gateway("resolved"),
                values: vec![1.0],
                usage: None,
                local_stats: None,
                pooling: None,
                normalized: None,
                metadata: SippResponseMetadata::default(),
            })
        }))
    }
}

// Executor whose runs never complete. The cancellation test uses it to keep a call in flight.
struct PendingExecutor;

impl GatewayExecutor for PendingExecutor {
    fn query(&self, _context: &GatewayRequestContext, _request: SippQueryRequest) -> SippTextRun {
        SippTextRun::from_response(Box::pin(futures_util::future::pending()))
    }

    // Same never-resolving run as `query`; the chat request itself is ignored.
    fn chat(&self, _context: &GatewayRequestContext, _request: SippChatRequest) -> SippTextRun {
        SippTextRun::from_response(Box::pin(futures_util::future::pending()))
    }

    fn embed(
        &self,
        _context: &GatewayRequestContext,
        _request: SippEmbedRequest,
    ) -> SippEmbeddingRun {
        SippEmbeddingRun::from_response(Box::pin(futures_util::future::pending()))
    }
}

// Canned successful text response: body "ok", stop reason `Stop`, no usage or local stats.
fn text_response() -> SippTextResponse {
    let text = String::from("ok");
    SippTextResponse {
        endpoint: EndpointRef::gateway("resolved"),
        text,
        finish_reason: FinishReason::Stop,
        metadata: SippResponseMetadata::default(),
        usage: None,
        local_stats: None,
    }
}