klieo-core 0.6.0

Core traits + runtime for the klieo agent framework.
Documentation
//! LLM-call retry policy shared by the blocking and streaming drivers.

use crate::error::{Error, LlmError};
use crate::llm::{ChatRequest, ChatResponse, ChunkStream, LlmClient};
use std::future::Future;
use std::time::Duration;
use tokio_util::sync::CancellationToken;
use tracing::debug;

/// Maximum number of LLM completion retries after the initial attempt
/// (so at most `MAX_LLM_RETRIES + 1` calls in total).
///
/// With [`LLM_RETRY_BASE_DELAY`] = 100 ms the exponential schedule is
/// 100 ms / 200 ms / 400 ms = 700 ms of total backoff, which keeps
/// transient-error recovery (timeouts, 5xx) under one second, i.e.
/// within a single user turn. That bound covers the exponential schedule
/// only: on `LlmError::RateLimit` the provider's `retry_after_secs` hint
/// can override the schedule (see `retry_call`), so rate-limited calls
/// may wait longer.
pub(crate) const MAX_LLM_RETRIES: u32 = 3;

/// Base backoff between LLM retries. The wait before retry `k`
/// (1-based) is `LLM_RETRY_BASE_DELAY * 2^(k - 1)`, giving
/// 100 ms / 200 ms / 400 ms before retries 1 / 2 / 3.
pub(crate) const LLM_RETRY_BASE_DELAY: Duration = Duration::from_millis(100);

/// Run a blocking (non-streaming) completion under the shared retry
/// policy implemented by [`retry_call`].
///
/// `req` is cloned once per attempt because `LlmClient::complete`
/// consumes its request; the caller's value is dropped when this returns.
pub(crate) async fn complete_with_retry(
    llm: &dyn LlmClient,
    cancel: &CancellationToken,
    req: ChatRequest,
) -> Result<ChatResponse, Error> {
    let one_attempt = || {
        let attempt_req = req.clone();
        llm.complete(attempt_req)
    };
    retry_call(cancel, one_attempt).await
}

/// Open a streaming chat under the same retry policy as
/// [`complete_with_retry`] (see [`retry_call`]).
///
/// Only failures to *open* the stream are retried; once a `ChunkStream`
/// is returned, errors surfacing while it is consumed are the caller's
/// concern. `req` is cloned for every attempt.
pub(crate) async fn stream_with_retry(
    llm: &dyn LlmClient,
    cancel: &CancellationToken,
    req: ChatRequest,
) -> Result<ChunkStream, Error> {
    let open_stream = || {
        let attempt_req = req.clone();
        llm.stream(attempt_req)
    };
    retry_call(cancel, open_stream).await
}

/// Generic bounded-retry helper shared by [`complete_with_retry`] and
/// [`stream_with_retry`]. Encapsulates the exponential-backoff
/// schedule, `RateLimit.retry_after_secs` honour, attempt cap, and
/// cancellation integration so the two call sites stay byte-identical
/// in policy.
///
/// Policy:
/// - At most `MAX_LLM_RETRIES + 1` total attempts.
/// - Backoff between attempts is `LLM_RETRY_BASE_DELAY * 2^n`
///   (100 ms / 200 ms / 400 ms).
/// - If the error carries a server-suggested `retry_after_secs`
///   ([`LlmError::RateLimit`]), that delay is used in place of the
///   exponential schedule.
/// - Only errors for which [`Error::retryable`] returns true are
///   retried — `Unauthorized`, `BadRequest`, etc. propagate
///   immediately.
/// - Cooperative cancellation is honoured between sleeps: if `cancel`
///   fires while waiting, the function aborts with
///   [`Error::Cancelled`] without consuming the remaining budget.
/// - On exhausting all attempts, the final `LlmError` is returned
///   wrapped in `Error::Llm`, unchanged.
pub(crate) async fn retry_call<T, F, Fut>(cancel: &CancellationToken, mut op: F) -> Result<T, Error>
where
    F: FnMut() -> Fut,
    Fut: Future<Output = Result<T, LlmError>> + Send,
{
    let max_attempts = MAX_LLM_RETRIES + 1;
    let mut attempt: u32 = 0;
    loop {
        attempt += 1;
        if cancel.is_cancelled() {
            return Err(Error::Cancelled);
        }

        match op().await {
            Ok(v) => return Ok(v),
            Err(e) => {
                let last_attempt = attempt >= max_attempts;
                let err: Error = e.into();
                if last_attempt || !err.retryable() {
                    return Err(err);
                }

                // Honour server-suggested retry-after when present, else
                // exponential backoff: 100 ms, 200 ms, 400 ms before
                // retry 1, 2, 3 respectively.
                let delay = match &err {
                    Error::Llm(LlmError::RateLimit { retry_after_secs }) => {
                        Duration::from_secs(u64::from(*retry_after_secs))
                    }
                    _ => LLM_RETRY_BASE_DELAY * 2u32.pow(attempt - 1),
                };

                debug!(
                    attempt,
                    delay_ms = delay.as_millis() as u64,
                    reason = %err,
                    "llm call failed; retrying after backoff"
                );

                tokio::select! {
                    _ = cancel.cancelled() => return Err(Error::Cancelled),
                    _ = tokio::time::sleep(delay) => {}
                }
            }
        }
    }
}

#[cfg(test)]
mod tests {
    //! Unit tests for [`complete_with_retry`] and the shared retry policy.
    //!
    //! Every test runs under `#[tokio::test(start_paused = true)]`: the
    //! runtime clock is virtual and auto-advances whenever all tasks are
    //! idle. The backoff sleeps (100 ms / 200 ms / 400 ms, or a server
    //! retry-after) therefore finish instantly in wall-clock time while
    //! still going through `tokio::time::sleep`. The fake LLM consumes no
    //! virtual time, so `tokio::time::Instant::elapsed` measures exactly
    //! the backoff that was waited. Several tests use that to pin the
    //! documented schedule. Lower bounds are asserted because the timer
    //! rounds deadlines up to whole milliseconds.
    use super::*;
    use crate::llm::{
        Capabilities, ChatRequest, ChatResponse, ChunkStream, Embedding, FinishReason, LlmClient,
        Message, Role,
    };
    use async_trait::async_trait;
    use std::sync::atomic::{AtomicU32, Ordering};

    /// A fake LLM that returns a configured `LlmError` for the first
    /// `fail_for` calls to `complete`, then a `Stop` text response
    /// ("ok") on every call after. `stream` and `embed` always return
    /// `Unsupported`.
    struct FailingThenSucceedingLlm {
        fail_for: u32,
        calls: AtomicU32,
        caps: Capabilities,
        make_err: Box<dyn Fn() -> LlmError + Send + Sync>,
    }

    impl FailingThenSucceedingLlm {
        fn new(fail_for: u32, make_err: impl Fn() -> LlmError + Send + Sync + 'static) -> Self {
            Self {
                fail_for,
                calls: AtomicU32::new(0),
                caps: Capabilities::default(),
                make_err: Box::new(make_err),
            }
        }

        /// Number of `complete` calls made so far.
        fn call_count(&self) -> u32 {
            self.calls.load(Ordering::SeqCst)
        }
    }

    #[async_trait]
    impl LlmClient for FailingThenSucceedingLlm {
        fn name(&self) -> &str {
            "failing-then-succeeding"
        }
        fn capabilities(&self) -> &Capabilities {
            &self.caps
        }
        async fn complete(&self, _req: ChatRequest) -> Result<ChatResponse, LlmError> {
            let n = self.calls.fetch_add(1, Ordering::SeqCst);
            if n < self.fail_for {
                return Err((self.make_err)());
            }
            Ok(ChatResponse {
                message: Message {
                    role: Role::Assistant,
                    content: "ok".into(),
                    tool_calls: vec![],
                    tool_call_id: None,
                },
                usage: Default::default(),
                finish_reason: FinishReason::Stop,
            })
        }
        async fn stream(&self, _req: ChatRequest) -> Result<ChunkStream, LlmError> {
            Err(LlmError::Unsupported("streaming".into()))
        }
        async fn embed(&self, _texts: &[String]) -> Result<Vec<Embedding>, LlmError> {
            Err(LlmError::Unsupported("embeddings".into()))
        }
    }

    fn req() -> ChatRequest {
        ChatRequest::new(vec![])
    }

    #[tokio::test(start_paused = true)]
    async fn retries_server_5xx_then_succeeds() {
        let llm = FailingThenSucceedingLlm::new(2, || LlmError::Server("503 unavailable".into()));
        let cancel = CancellationToken::new();
        let start = tokio::time::Instant::now();

        let resp = complete_with_retry(&llm, &cancel, req())
            .await
            .expect("should succeed after retries");
        assert_eq!(resp.message.content, "ok");
        assert_eq!(
            llm.call_count(),
            3,
            "two failures plus one success = three attempts"
        );
        // Two backoffs: 100 ms + 200 ms.
        let waited = start.elapsed();
        assert!(
            waited >= Duration::from_millis(300),
            "expected >= 300 ms of backoff, waited {waited:?}"
        );
    }

    #[tokio::test(start_paused = true)]
    async fn unauthorized_does_not_retry() {
        let llm = FailingThenSucceedingLlm::new(u32::MAX, || LlmError::Unauthorized);
        let cancel = CancellationToken::new();
        let start = tokio::time::Instant::now();

        let err = complete_with_retry(&llm, &cancel, req())
            .await
            .expect_err("must error");
        assert!(
            matches!(err, Error::Llm(LlmError::Unauthorized)),
            "expected Llm(Unauthorized), got {err:?}"
        );
        assert_eq!(llm.call_count(), 1, "non-retryable: exactly one attempt");
        assert!(
            start.elapsed() < LLM_RETRY_BASE_DELAY,
            "non-retryable errors must not wait for a backoff"
        );
    }

    #[tokio::test(start_paused = true)]
    async fn cancellation_during_backoff_aborts() {
        let llm = FailingThenSucceedingLlm::new(u32::MAX, || LlmError::Timeout);
        let cancel = CancellationToken::new();

        let cancel_clone = cancel.clone();
        let canceller = tokio::spawn(async move {
            tokio::task::yield_now().await;
            cancel_clone.cancel();
        });

        let err = complete_with_retry(&llm, &cancel, req())
            .await
            .expect_err("must error");
        canceller.await.unwrap();

        assert!(
            matches!(err, Error::Cancelled),
            "expected Cancelled, got {err:?}"
        );
        assert_eq!(llm.call_count(), 1);
    }

    #[tokio::test(start_paused = true)]
    async fn exhausts_attempts_and_propagates_last_error() {
        let llm = FailingThenSucceedingLlm::new(u32::MAX, || LlmError::Server("boom".into()));
        let cancel = CancellationToken::new();
        let start = tokio::time::Instant::now();

        let err = complete_with_retry(&llm, &cancel, req())
            .await
            .expect_err("must error");
        assert!(
            matches!(err, Error::Llm(LlmError::Server(ref m)) if m == "boom"),
            "expected Llm(Server(\"boom\")), got {err:?}"
        );
        assert_eq!(llm.call_count(), MAX_LLM_RETRIES + 1);
        // Full exponential schedule: 100 + 200 + 400 ms, under one second.
        let waited = start.elapsed();
        assert!(
            waited >= Duration::from_millis(700),
            "expected >= 700 ms of backoff, waited {waited:?}"
        );
        assert!(
            waited < Duration::from_secs(1),
            "exponential schedule must stay under one second, waited {waited:?}"
        );
    }

    #[tokio::test(start_paused = true)]
    async fn rate_limit_uses_retry_after_then_succeeds() {
        let llm = FailingThenSucceedingLlm::new(1, || LlmError::RateLimit {
            retry_after_secs: 1,
        });
        let cancel = CancellationToken::new();
        let start = tokio::time::Instant::now();

        let resp = complete_with_retry(&llm, &cancel, req())
            .await
            .expect("should succeed after rate-limit retry");
        assert_eq!(resp.message.content, "ok");
        assert_eq!(llm.call_count(), 2);
        // The 1 s server hint, not the 100 ms exponential step, governs.
        let waited = start.elapsed();
        assert!(
            waited >= Duration::from_secs(1),
            "expected the 1 s retry-after to be honoured, waited {waited:?}"
        );
    }
}