use std::sync::Arc;
use std::time::Duration;
use dynamo_llm::protocols::common::preprocessor::PreprocessedRequest;
use dynamo_llm::protocols::common::{FinishReason, OutputOptions, SamplingOptions, StopConditions};
use dynamo_runtime::engine::AsyncEngineContext;
use dynamo_runtime::pipeline::{AsyncEngineContextProvider, Context};
use futures::StreamExt;
use crate::engine::{GenerateContext, LLMEngine};
use ConformanceFailure::*;
/// How long [`run_conformance`] waits for a cancelled stream to finish (2 seconds).
const DEFAULT_CANCEL_DEADLINE: Duration = Duration::from_millis(2_000);
pub fn mock_context() -> Arc<dyn AsyncEngineContext> {
Context::<()>::new(()).context()
}
pub fn cancelling_context(after: Duration) -> Arc<dyn AsyncEngineContext> {
let ctx = Context::<()>::new(()).context();
let ctx2 = ctx.clone();
tokio::spawn(async move {
tokio::time::sleep(after).await;
ctx2.stop_generating();
});
ctx
}
/// Reasons the conformance suite ([`run_conformance`]) can reject an engine.
///
/// Variants carrying a `String` hold the `to_string()` rendering of the
/// underlying error, or a short description of what went wrong.
#[derive(Debug)]
pub enum ConformanceFailure {
    /// `start()` returned an error.
    StartFailed(String),
    /// `start()` succeeded but the returned config's `model` was empty.
    EmptyModelInConfig,
    /// A `generate()` call returned an error instead of a stream.
    GenerateFailed(String),
    /// A generate stream finished without yielding any item.
    NoChunksYielded,
    /// Another chunk followed the first chunk whose `finish_reason` was set.
    ChunkAfterTerminal,
    /// No chunk in the stream had `finish_reason` set.
    NoTerminalChunk,
    /// A generate stream yielded an `Err` item.
    StreamYieldedError(String),
    /// One of the concurrent `generate()` calls failed or its stream was unacceptable.
    ConcurrentGenerateFailed(String),
    /// After `stop_generating()`, the stream did not end within `after`.
    CancellationNotObserved { after: Duration },
    /// After cancellation the stream ended, but its last chunk's `finish_reason`
    /// was not `FinishReason::Cancelled`.
    CancellationIgnored,
    /// The first `cleanup()` call on a started engine failed.
    CleanupFailed(String),
    /// The repeated (second) `cleanup()` call on the same engine failed.
    SecondCleanupFailed(String),
    /// `cleanup()` on a freshly built, never-started engine failed.
    CleanupWithoutStartFailed(String),
}
impl std::fmt::Display for ConformanceFailure {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
StartFailed(m) => write!(f, "start() failed: {m}"),
EmptyModelInConfig => write!(f, "EngineConfig.model is empty"),
GenerateFailed(m) => write!(f, "generate() failed: {m}"),
NoChunksYielded => write!(f, "generate() stream yielded no chunks"),
ChunkAfterTerminal => write!(f, "chunk yielded after terminal chunk"),
NoTerminalChunk => write!(f, "stream ended without a terminal chunk"),
StreamYieldedError(m) => write!(f, "engine stream yielded Err: {m}"),
ConcurrentGenerateFailed(m) => {
write!(f, "concurrent generate() calls failed: {m}")
}
CancellationNotObserved { after } => write!(
f,
"stream did not terminate within {after:?} after cancellation"
),
CancellationIgnored => write!(
f,
"stream terminated but terminal chunk's finish_reason was not Cancelled \
(engine must emit FinishReason::Cancelled when it observes cancellation)"
),
CleanupFailed(m) => write!(f, "cleanup() failed: {m}"),
SecondCleanupFailed(m) => {
write!(f, "second cleanup() call failed (must be idempotent): {m}")
}
CleanupWithoutStartFailed(m) => write!(
f,
"cleanup() failed on a never-started engine: {m} \
(Worker calls cleanup() after start() raises, so engines must \
be null-safe against partial / no allocation)"
),
}
}
}
impl std::error::Error for ConformanceFailure {}
/// Runs the engine conformance suite against engines produced by `factory`.
///
/// Sequence:
/// 1. build an engine, call `start(0)`, and require a non-empty `EngineConfig.model`;
/// 2. run the single-generate, concurrent-generate and cancellation checks;
/// 3. call `cleanup()` twice on that engine (the second call must also succeed);
/// 4. build a *fresh* engine and call `cleanup()` on it without ever starting it.
///
/// If step 1 or 2 fails, `cleanup()` is still called once on that engine,
/// best-effort with its result ignored. Whatever `start()` acquired is therefore
/// released before the original failure is returned. This also mirrors the
/// Worker, which calls `cleanup()` after `start()` raises.
///
/// # Errors
///
/// Returns the first [`ConformanceFailure`] encountered.
pub async fn run_conformance<E, F>(mut factory: F) -> Result<(), ConformanceFailure>
where
    E: LLMEngine,
    F: FnMut() -> E,
{
    let engine = factory();
    if let Err(failure) = start_and_run_checks(&engine).await {
        // Best-effort release; the check failure is the more useful report.
        let _ = engine.cleanup().await;
        return Err(failure);
    }
    engine
        .cleanup()
        .await
        .map_err(|e| CleanupFailed(e.to_string()))?;
    engine
        .cleanup()
        .await
        .map_err(|e| SecondCleanupFailed(e.to_string()))?;
    let fresh = factory();
    fresh
        .cleanup()
        .await
        .map_err(|e| CleanupWithoutStartFailed(e.to_string()))?;
    Ok(())
}

/// Starts `engine`, validates its config, and runs every generate/cancellation check.
async fn start_and_run_checks<E: LLMEngine>(engine: &E) -> Result<(), ConformanceFailure> {
    let config = engine
        .start(0)
        .await
        .map_err(|e| StartFailed(e.to_string()))?;
    if config.model.is_empty() {
        return Err(EmptyModelInConfig);
    }
    check_single_generate(engine, &config.model).await?;
    check_concurrent_generates(engine, &config.model).await?;
    check_cancellation(engine, &config.model, DEFAULT_CANCEL_DEADLINE).await?;
    Ok(())
}
fn request(model: &str) -> PreprocessedRequest {
request_with_max_tokens(model, None)
}
fn request_with_max_tokens(model: &str, max_tokens: Option<u32>) -> PreprocessedRequest {
PreprocessedRequest::builder()
.model(model.to_string())
.token_ids(vec![1, 2, 3])
.stop_conditions(StopConditions {
max_tokens,
..Default::default()
})
.sampling_options(SamplingOptions::default())
.output_options(OutputOptions::default())
.build()
.expect("build request")
}
/// Runs one `generate()` call to completion and validates the stream's shape.
///
/// The stream must yield at least one item and no `Err` items. Exactly one
/// chunk may have `finish_reason` set, and it must be the final chunk.
async fn check_single_generate<E: LLMEngine>(
    engine: &E,
    model: &str,
) -> Result<(), ConformanceFailure> {
    let ctx = mock_context();
    let stream = engine
        .generate(request(model), GenerateContext::new(ctx, None))
        .await
        .map_err(|e| GenerateFailed(e.to_string()))?;

    let items: Vec<_> = stream.collect().await;
    if items.is_empty() {
        return Err(NoChunksYielded);
    }

    // Collecting into `Result` stops at the first `Err`, in stream order.
    let chunks = items
        .into_iter()
        .collect::<Result<Vec<_>, _>>()
        .map_err(|e| StreamYieldedError(e.to_string()))?;

    // Only the first terminal chunk matters. If it is not the last chunk, then
    // something (a plain chunk or a second terminal) follows it.
    match chunks.iter().position(|chunk| chunk.finish_reason.is_some()) {
        None => Err(NoTerminalChunk),
        Some(idx) if idx + 1 == chunks.len() => Ok(()),
        Some(_) => Err(ChunkAfterTerminal),
    }
}
/// Issues `CONCURRENT` `generate()` calls at once and drains every stream.
///
/// Each call must succeed. Its stream must yield at least one item, and none of
/// the items may be an `Err`. Any violation is reported as
/// [`ConformanceFailure::ConcurrentGenerateFailed`]. When several calls fail,
/// the first one in submission order is reported.
async fn check_concurrent_generates<E: LLMEngine>(
    engine: &E,
    model: &str,
) -> Result<(), ConformanceFailure> {
    const CONCURRENT: usize = 8;
    let calls = (0..CONCURRENT).map(|_| async {
        let ctx = mock_context();
        let stream = engine
            .generate(request(model), GenerateContext::new(ctx, None))
            .await
            .map_err(|e| ConcurrentGenerateFailed(e.to_string()))?;
        let items: Vec<_> = stream.collect().await;
        if items.is_empty() {
            return Err(ConcurrentGenerateFailed("stream was empty".to_string()));
        }
        // Counting items alone would let a stream of nothing but errors pass.
        // Reject `Err` items, as check_single_generate does.
        if let Some(e) = items.iter().find_map(|item| item.as_ref().err()) {
            return Err(ConcurrentGenerateFailed(format!("stream yielded Err: {e}")));
        }
        Ok::<(), ConformanceFailure>(())
    });
    for result in futures::future::join_all(calls).await {
        result?;
    }
    Ok(())
}
async fn check_cancellation<E: LLMEngine>(
engine: &E,
model: &str,
deadline: Duration,
) -> Result<(), ConformanceFailure> {
const LONG_MAX_TOKENS: u32 = 10_000;
let ctx = mock_context();
let stream = engine
.generate(
request_with_max_tokens(model, Some(LONG_MAX_TOKENS)),
GenerateContext::new(ctx.clone(), None),
)
.await
.map_err(|e| GenerateFailed(e.to_string()))?;
ctx.stop_generating();
let items = tokio::time::timeout(deadline, async {
let mut s = stream;
let mut out = Vec::new();
while let Some(c) = s.next().await {
out.push(c);
}
out
})
.await
.map_err(|_| CancellationNotObserved { after: deadline })?;
match items.last() {
Some(Ok(c)) if matches!(c.finish_reason, Some(FinishReason::Cancelled)) => Ok(()),
Some(Ok(_)) => Err(CancellationIgnored),
Some(Err(e)) => Err(StreamYieldedError(e.to_string())),
None => Err(NoChunksYielded),
}
}