// @harn-entrypoint-category llm.stdlib
//
// std/llm/handlers — composable middleware around the (call) -> envelope
// caller seam wired into agent_loop. Each `with_*` returns a NEW caller
// closure that wraps `next`. Stack wrappers with `compose`; the leftmost
// wrapper in the list becomes the outermost layer.
//
// Caller shape:
// call = {prompt, system, opts, turn: {iteration, session_id, attempt}}
// result = {ok: true, value: <llm dict>}
// | {ok: false, status: <reserved>, error?, retryable?}
//
// Reserved statuses: "budget_exhausted", "transport_error", "caller_aborted",
// "caller_skipped", "exception", "schema_validation", "rate_limited",
// "timeout", "network", "provider_5xx", "stream_interrupt",
// "context_window_exceeded", "auth", "policy_blocked", "circuit_open".
import { agent_emit_event } from "std/agent/state"
import { cache_get, cache_put } from "std/cache"
/**
 * default_llm_caller builds the innermost caller: a closure taking a `call`
 * dict and returning a result envelope. Per the module notes it mirrors the
 * built-in __default_invoke_llm body in std/agent/loop, so it is meant to sit
 * at the bottom of a middleware stack:
 *
 *   let caller = with_retry(default_llm_caller(), {...})
 *
 * Envelope produced:
 *   - {ok: true, value: <llm_call result>} when llm_call succeeds
 *   - {ok: false, status: "budget_exhausted"} when the thrown error is a dict
 *     whose `reason` is "budget_exceeded"
 *   - {ok: false, status: "exception", error: <thrown value>} otherwise
 */
pub fn default_llm_caller() {
  return { call ->
    let attempt = try {
      llm_call(call.prompt, call.system, call.opts)
    }
    if is_err(attempt) {
      let failure = unwrap_err(attempt)
      // Only dict-shaped errors can carry a machine-readable reason.
      var reason = ""
      if type_of(failure) == "dict" {
        reason = failure?.reason ?? ""
      }
      if reason == "budget_exceeded" {
        return {ok: false, status: "budget_exhausted"}
      }
      return {ok: false, status: "exception", error: failure}
    }
    return {ok: true, value: unwrap(attempt)}
  }
}
// -------------------------------------------------------------------------------------------------
// shared helpers
// -------------------------------------------------------------------------------------------------
// Normalizes an options argument: dicts pass through untouched, anything
// else (nil, strings, lists, ...) becomes an empty dict.
fn __opts_dict(opts) {
  let normalized = if type_of(opts) == "dict" {
    opts
  } else {
    {}
  }
  return normalized
}
// Names of circuits this module has already registered. Read and appended to
// by __ensure_llm_handler_circuit so circuit_breaker() runs once per name.
var __llm_handler_circuits = []
// True when type_of(value) reports one of the callable kinds this module
// recognizes: "function", "closure" or "fn".
fn __llm_handler_is_callable(value) -> bool {
  let callable_kinds = ["function", "closure", "fn"]
  return contains(callable_kinds, type_of(value))
}
// Derives the default circuit name "llm:<provider>:<model>" from call.opts.
// A missing call, opts, provider or model is rendered as "<unset>".
fn __circuit_name_for(call) -> string {
  let call_opts = call?.opts ?? {}
  let provider_label = to_string(call_opts?.provider ?? "<unset>")
  let model_label = to_string(call_opts?.model ?? "<unset>")
  let prefix = "llm:" + provider_label
  return prefix + ":" + model_label
}
// Registers the circuit `name` with circuit_breaker(name, threshold, reset_ms)
// the first time it is seen and remembers it in __llm_handler_circuits;
// later calls with the same name do nothing.
fn __ensure_llm_handler_circuit(name, threshold, reset_ms) {
  if contains(__llm_handler_circuits, name) {
    return
  }
  circuit_breaker(name, threshold, reset_ms)
  __llm_handler_circuits = __llm_handler_circuits + [name]
}
// Builds the error dict thrown when a circuit is open. It records the
// circuit name plus the provider/model read from call.opts ("<unset>" when
// absent) and is tagged kind "terminal", reason/category "circuit_open".
fn __circuit_open_error(call, name) {
  let call_opts = call?.opts ?? {}
  let provider = call_opts?.provider ?? "<unset>"
  let model = call_opts?.model ?? "<unset>"
  let description = "circuit open: " + name
  return {
    kind: "terminal",
    reason: "circuit_open",
    category: "circuit_open",
    message: description,
    provider: provider,
    model: model,
    circuit: name,
  }
}
// Decides whether a direct with_cache call should bypass the cache.
//   - no opts.skip_when: skip exactly when opts.tools is set
//   - callable skip_when: called with {prompt, system, options}; its result
//     is coerced to a bool
//   - any other skip_when: the value itself is coerced to a bool
fn __llm_handler_should_skip_cache(prompt, system, opts) -> bool {
  let skip_rule = opts?.skip_when
  if skip_rule == nil {
    return opts?.tools != nil
  }
  var verdict = skip_rule
  if __llm_handler_is_callable(skip_rule) {
    verdict = skip_rule({prompt: prompt, system: system, options: opts})
  }
  return verdict ? true : false
}
// Projects the caller's options onto the option dict handed to
// cache_get/cache_put. Defaults: store "llm.with_cache", max_entries 256,
// and ttl "10m" only when none of ttl / ttl_seconds / max_age_seconds
// was supplied.
fn __llm_handler_cache_options(opts) -> dict {
  let explicit_ttl = opts?.ttl
  let ttl_seconds = opts?.ttl_seconds
  let max_age_seconds = opts?.max_age_seconds
  let has_expiry = explicit_ttl != nil || ttl_seconds != nil || max_age_seconds != nil
  let effective_ttl = if has_expiry {
    explicit_ttl
  } else {
    "10m"
  }
  return {
    store: opts?.store ?? "llm.with_cache",
    backend: opts?.backend,
    namespace: opts?.namespace,
    name: opts?.name,
    path: opts?.path,
    cache_dir: opts?.cache_dir,
    ttl: effective_ttl,
    ttl_seconds: ttl_seconds,
    max_age_seconds: max_age_seconds,
    max_entries: opts?.max_entries ?? 256,
  }
}
// Calls next(call) and always hands back an envelope dict:
//   - a thrown value becomes {ok: false, status: "exception", error: <thrown>}
//   - a non-dict return becomes an "exception" envelope with the error
//     "caller returned non-dict"
//   - a dict return is passed through unchanged
fn __safe_invoke(next, call) {
  let attempt = try {
    next(call)
  }
  if !is_err(attempt) {
    let produced = unwrap(attempt)
    if type_of(produced) == "dict" {
      return produced
    }
    return {ok: false, status: "exception", error: "caller returned non-dict"}
  }
  return {ok: false, status: "exception", error: unwrap_err(attempt)}
}
// Best-effort event emission: forwards to agent_emit_event only when
// session_id stringifies to something non-empty, and swallows any error the
// emitter throws so it cannot fail the surrounding call.
fn __emit_event(session_id, name, payload) {
  if to_string(session_id) != "" {
    let _ = try {
      agent_emit_event(session_id, name, payload)
    }
  }
}
// Default retry decision for with_retry. Returns true only for failed
// envelopes whose status is in the retryable list, or whose status is unknown
// to both lists but carries a truthy `retryable` hint. Statuses in the
// never-retry list win over the hint. Non-dict and ok envelopes never retry.
fn __retry_default_predicate(envelope) {
  if type_of(envelope) != "dict" {
    return false
  }
  if envelope?.ok ?? false {
    return false
  }
  let status = to_string(envelope?.status ?? "")
  let terminal_statuses = [
    "schema_validation",
    "auth",
    "budget_exhausted",
    "context_window_exceeded",
    "policy_blocked",
    "caller_aborted",
    "caller_skipped",
    "circuit_open",
  ]
  if contains(terminal_statuses, status) {
    return false
  }
  let transient_statuses = [
    "transient",
    "rate_limited",
    "timeout",
    "exception",
    "network",
    "provider_5xx",
    "stream_interrupt",
  ]
  if contains(transient_statuses, status) {
    return true
  }
  // Unknown status: defer to the envelope's own retryable hint.
  var hinted = false
  if envelope?.retryable ?? false {
    hinted = true
  }
  return hinted
}
// Converts a raw retry hint into a non-negative integer, multiplied by
// `scale` (1 for milliseconds, 1000 for seconds). Anything to_int cannot turn
// into a number -- whether it throws or yields nil -- and any negative value
// collapses to 0, meaning "no server-provided delay".
fn __retry_after_to_ms(raw, scale) {
  let parsed = try {
    to_int(raw)
  }
  if is_err(parsed) {
    return 0
  }
  let amount = unwrap(parsed)
  if amount == nil {
    return 0
  }
  if amount < 0 {
    return 0
  }
  return amount * scale
}

// Extracts the server-requested retry delay, in milliseconds, from a failed
// envelope. Looks at envelope.error (when it is a dict) in this order:
//   1. error.retry_after_ms, taken as milliseconds
//   2. the first error.headers entry whose key equals "retry-after"
//      case-insensitively and whose value is non-empty, taken as seconds
// Returns 0 when no usable hint exists. A Retry-After header may legally be
// an HTTP-date rather than a number of seconds; such values (and any other
// unparseable hint) yield 0 instead of raising, so with_retry keeps its
// "never throws" guarantee and falls back to its own backoff.
fn __retry_after_ms(envelope) {
  if type_of(envelope) != "dict" {
    return 0
  }
  let err = envelope?.error
  if type_of(err) != "dict" {
    return 0
  }
  let explicit_ms = err?.retry_after_ms
  if explicit_ms != nil {
    return __retry_after_to_ms(explicit_ms, 1)
  }
  let headers = err?.headers
  if type_of(headers) != "dict" {
    return 0
  }
  for key in headers.keys() {
    if lowercase(to_string(key)) == "retry-after" {
      let raw = to_string(headers[key])
      if raw != "" {
        return __retry_after_to_ms(raw, 1000)
      }
    }
  }
  return 0
}
// Computes the wait, in milliseconds, before the retry that follows attempt
// number `attempt` (1-based).
//   - backoff "linear": base_ms * attempt
//   - anything else (exponential, jittered): base_ms * 2^(attempt - 1)
// The result is capped at max_ms, then jittered:
//   - backoff "jittered" or jitter "full": uniform in [0, delay)
//   - jitter "equal": half the delay plus a uniform share of the other half
//   - any other jitter value: no randomization
// Attempts below 1 and negative results yield 0.
fn __backoff_delay(attempt, base_ms, max_ms, backoff, jitter) {
  if attempt < 1 {
    return 0
  }
  var delay = base_ms
  if backoff == "linear" {
    delay = base_ms * attempt
  } else {
    // Double once per prior attempt, but stop as soon as the cap is reached
    // (or the delay is non-positive and doubling cannot grow it). This keeps
    // the intermediate value bounded so a large attempt count cannot overflow
    // the integer before the max_ms clamp below is applied.
    var step = 1
    while step < attempt && delay > 0 && delay < max_ms {
      delay = delay * 2
      step = step + 1
    }
  }
  if delay > max_ms {
    delay = max_ms
  }
  if backoff == "jittered" || jitter == "full" {
    let r = random()
    delay = to_int(r * to_float(delay))
  } else if jitter == "equal" {
    let half = delay / 2
    let r = random()
    delay = half + to_int(r * to_float(half))
  }
  if delay < 0 {
    delay = 0
  }
  return delay
}
// -------------------------------------------------------------------------------------------------
// with_retry
// -------------------------------------------------------------------------------------------------
/**
 * with_retry(next, opts) -> caller
 *
 * Wraps `next` with bounded retry. Default opts:
 *   {max_attempts: 3, base_ms: 250, max_ms: 8000,
 *    backoff: "exponential", jitter: "full", honor_retry_after: true}
 *
 * The default predicate retries on these statuses:
 *   transient, rate_limited, timeout, exception, network,
 *   provider_5xx, stream_interrupt
 *
 * The default predicate NEVER retries:
 *   schema_validation, auth, budget_exhausted, context_window_exceeded,
 *   policy_blocked, caller_aborted, caller_skipped, circuit_open
 *
 * For any other status it retries only when the envelope carries a truthy
 * `retryable` hint.
 *
 * `opts.predicate(envelope) -> bool` overrides the default; a predicate that
 * throws is treated as "do not retry".
 * When honor_retry_after is true, the wait between attempts is the larger of
 * the computed backoff and the hint from `error.retry_after_ms` or a
 * case-insensitive `Retry-After` header.
 *
 * Each attempt receives the call with `turn.attempt` set to the 1-based
 * attempt number.
 *
 * Returns the LAST envelope unchanged plus `retries_attempted: N`, where N is
 * the number of re-invocations after the first call (never negative). If
 * max_attempts is below 1, `next` is never invoked and an "exception"
 * envelope with error "with_retry: no attempts" is returned.
 * Never throws — raw throws from `next` become {ok: false, status: "exception"}.
 */
pub fn with_retry(next, opts = nil) {
  let cfg = __opts_dict(opts)
  let max_attempts = cfg?.max_attempts ?? 3
  let base_ms = cfg?.base_ms ?? 250
  let max_ms = cfg?.max_ms ?? 8000
  let backoff = cfg?.backoff ?? "exponential"
  let jitter = cfg?.jitter ?? "full"
  let honor_retry_after = cfg?.honor_retry_after ?? true
  let predicate = cfg?.predicate
  return { call ->
    let base_turn = call?.turn ?? {iteration: 0, session_id: "", attempt: 1}
    var attempt = 1
    var last_envelope = {ok: false, status: "exception", error: "with_retry: no attempts"}
    while attempt <= max_attempts {
      let attempt_call = call + {turn: base_turn + {attempt: attempt}}
      let envelope = __safe_invoke(next, attempt_call)
      last_envelope = envelope
      let should_retry = if predicate != nil {
        let r = try {
          predicate(envelope)
        }
        if is_err(r) {
          false
        } else {
          unwrap(r)
        }
      } else {
        __retry_default_predicate(envelope)
      }
      if !should_retry || attempt >= max_attempts {
        return envelope + {retries_attempted: attempt - 1}
      }
      let header_delay = if honor_retry_after {
        __retry_after_ms(envelope)
      } else {
        0
      }
      let backoff_delay = __backoff_delay(attempt, base_ms, max_ms, backoff, jitter)
      let delay = if header_delay > backoff_delay {
        header_delay
      } else {
        backoff_delay
      }
      if delay > 0 {
        sleep_ms(delay)
      }
      attempt = attempt + 1
    }
    // The loop only falls through when its condition fails on entry
    // (max_attempts < 1) or max_attempts is not a whole number. At this point
    // `attempt` is one past the number of invocations made, so the retry
    // count is attempt - 2, floored at 0. The previous `max_attempts - 1`
    // reported a negative count when no attempt was made.
    let retries = if attempt > 1 {
      attempt - 2
    } else {
      0
    }
    return last_envelope + {retries_attempted: retries}
  }
}
// -------------------------------------------------------------------------------------------------
// with_fallback
// -------------------------------------------------------------------------------------------------
/**
 * with_fallback(callers) -> caller
 *
 * Builds a caller that walks `callers` (a list of caller closures) in order
 * and stops at the first one whose envelope is ok. After every attempt an
 * `llm_fallback_attempt` event is emitted (only when call.turn.session_id is
 * non-empty) carrying {fallback_index, fallback_total, ok, status}.
 *
 * Results:
 *   - first success: that envelope + {fallback_index, fallback_total}
 *   - every caller failed: the last envelope + {fallback_total}
 *   - `callers` empty or not a list: {ok: false, status: "caller_skipped",
 *     error: "with_fallback: empty caller list"}
 */
pub fn with_fallback(callers) {
  let is_list = type_of(callers) == "list"
  let total = is_list ? len(callers) : 0
  return { call ->
    if total == 0 {
      return {ok: false, status: "caller_skipped", error: "with_fallback: empty caller list"}
    }
    let session_id = to_string(call?.turn?.session_id ?? "")
    var latest = {ok: false, status: "caller_skipped"}
    var position = 0
    for candidate in callers {
      let envelope = __safe_invoke(candidate, call)
      let succeeded = envelope?.ok ?? false
      let attempt_report = {
        fallback_index: position,
        fallback_total: total,
        ok: succeeded,
        status: to_string(envelope?.status ?? ""),
      }
      __emit_event(session_id, "llm_fallback_attempt", attempt_report)
      if succeeded {
        return envelope + {fallback_index: position, fallback_total: total}
      }
      latest = envelope
      position = position + 1
    }
    return latest + {fallback_total: total}
  }
}
// Returns envelope.value.text as a string for ok envelopes; failed, nil or
// text-less envelopes yield "".
fn __envelope_text(env) {
  if env?.ok ?? false {
    return to_string(env?.value?.text ?? "")
  }
  return ""
}
// -------------------------------------------------------------------------------------------------
// with_shadow
// -------------------------------------------------------------------------------------------------
/**
 * with_shadow(primary, shadow, opts) -> caller
 *
 * Runs primary and shadow via `parallel each` over a 2-element list and
 * always returns the PRIMARY envelope. Options:
 *   - sampler: closure(call) -> bool deciding whether to shadow this call
 *     (default: always). A sampler that throws or returns nil means "do not
 *     shadow"; in that case only primary runs.
 *   - on_diff: closure(primary_env, shadow_env) -> nil, called on a diff;
 *     errors it throws are swallowed.
 *   - diff_when: "any" | "ok_only" (default "any"). With "ok_only" a diff is
 *     reported only when BOTH envelopes are ok.
 *
 * Two envelopes "differ" when their value.text differs; a failed envelope
 * counts as empty text. On a diff, emits an `llm_shadow_diff` event (only
 * when call.turn.session_id is non-empty) with ok flags, statuses and text
 * lengths of both sides.
 */
pub fn with_shadow(primary, shadow, opts = nil) {
  let cfg = __opts_dict(opts)
  let sampler = cfg?.sampler
  let on_diff = cfg?.on_diff
  let diff_when = cfg?.diff_when ?? "any"
  return { call ->
    let sample = if sampler != nil {
      let r = try {
        sampler(call)
      }
      if is_err(r) {
        false
      } else {
        let v = unwrap(r)
        if v == nil {
          false
        } else {
          !!v
        }
      }
    } else {
      true
    }
    if !sample {
      return __safe_invoke(primary, call)
    }
    let pair = [primary, shadow]
    let results = parallel each pair { c ->
      __safe_invoke(c, call)
    }
    let primary_env = results[0]
    let shadow_env = results[1]
    let primary_text = __envelope_text(primary_env)
    let shadow_text = __envelope_text(shadow_env)
    // Parenthesized so the result does not depend on how `??` and `&&` rank
    // against each other: both sides must independently be ok.
    let both_ok = (primary_env?.ok ?? false) && (shadow_env?.ok ?? false)
    let differs = primary_text != shadow_text
    let should_compare = if diff_when == "ok_only" {
      both_ok && differs
    } else {
      differs
    }
    if should_compare {
      __emit_event(
        to_string(call?.turn?.session_id ?? ""),
        "llm_shadow_diff",
        {
          primary_ok: primary_env?.ok ?? false,
          shadow_ok: shadow_env?.ok ?? false,
          primary_status: to_string(primary_env?.status ?? "ok"),
          shadow_status: to_string(shadow_env?.status ?? "ok"),
          primary_len: len(primary_text),
          shadow_len: len(shadow_text),
        },
      )
      if on_diff != nil {
        let _ = try {
          on_diff(primary_env, shadow_env)
        }
      }
    }
    return primary_env
  }
}
// -------------------------------------------------------------------------------------------------
// with_prompt_rewrite
// -------------------------------------------------------------------------------------------------
/**
 * with_prompt_rewrite(next, rewriter) -> caller
 *
 * Calls `rewriter(prompt, system, opts)` before delegating to `next`. The
 * rewriter may return a dict with any subset of {prompt, system, opts};
 * keys that are absent or nil keep the original value, and a non-dict
 * return changes nothing. `call.turn` is carried over (or defaulted to
 * {iteration: 0, session_id: "", attempt: 1} when missing).
 *
 * If the rewriter throws, `next` is not called and the result is
 * {ok: false, status: "exception", error: <thrown value>}.
 */
pub fn with_prompt_rewrite(next, rewriter) {
  return { call ->
    let rewritten = try {
      rewriter(call?.prompt, call?.system, call?.opts)
    }
    if is_err(rewritten) {
      return {ok: false, status: "exception", error: unwrap_err(rewritten)}
    }
    // Anything other than a dict is treated as "no overrides".
    let overrides = __opts_dict(unwrap(rewritten))
    let forwarded = {
      prompt: overrides?.prompt ?? call?.prompt,
      system: overrides?.system ?? call?.system,
      opts: overrides?.opts ?? call?.opts,
      turn: call?.turn ?? {iteration: 0, session_id: "", attempt: 1},
    }
    return __safe_invoke(next, forwarded)
  }
}
// -------------------------------------------------------------------------------------------------
// with_logging
// -------------------------------------------------------------------------------------------------
/**
 * with_logging(next, opts) -> caller
 *
 * Wraps `next` and produces one structured record per call, then returns
 * the envelope from `next` untouched. Options:
 *   - level: copied into the record ("debug" | "info" | "warn", default "info")
 *   - include_prompt: when true, adds `prompt` and `system` to the record
 *     (default false; PII)
 *   - sink: closure(record) -> nil, called with the record; errors it throws
 *     are swallowed
 *
 * Record fields: event, level, latency_ms, model, provider, status ("ok" or
 * the envelope status, "unknown" if absent), iteration, attempt.
 * The record is also emitted as an `llm_call_log` event when
 * call.turn.session_id is non-empty.
 */
pub fn with_logging(next, opts = nil) {
  let cfg = __opts_dict(opts)
  let level = cfg?.level ?? "info"
  let include_prompt = cfg?.include_prompt ?? false
  let sink = cfg?.sink
  return { call ->
    let started_at = now_ms()
    let envelope = __safe_invoke(next, call)
    let call_opts = __opts_dict(call?.opts)
    var outcome = "ok"
    if !(envelope?.ok ?? false) {
      outcome = to_string(envelope?.status ?? "unknown")
    }
    let base_record = {
      event: "llm_call_log",
      level: level,
      latency_ms: now_ms() - started_at,
      model: to_string(call_opts?.model ?? ""),
      provider: to_string(call_opts?.provider ?? ""),
      status: outcome,
      iteration: to_int(call?.turn?.iteration ?? 0),
      attempt: to_int(call?.turn?.attempt ?? 1),
    }
    let record = if include_prompt {
      base_record + {prompt: to_string(call?.prompt ?? ""), system: to_string(call?.system ?? "")}
    } else {
      base_record
    }
    let session_id = to_string(call?.turn?.session_id ?? "")
    __emit_event(session_id, "llm_call_log", record)
    if sink != nil {
      let _ = try {
        sink(record)
      }
    }
    return envelope
  }
}
// Checks one budget limit. Returns nil while the limit is unset or
// `observed` is still below it. Once the limit is reached it either throws
// "with_budget: <limit_name> exceeded" (on_exceed == "throw") or returns a
// "budget_exhausted" envelope describing the limit, its value and the
// observed amount.
fn __budget_check(limit_name, max_value, observed, on_exceed) {
  if max_value == nil {
    return nil
  }
  if observed < to_int(max_value) {
    return nil
  }
  if on_exceed == "throw" {
    throw "with_budget: " + limit_name + " exceeded"
  }
  let detail = {limit: limit_name, value: max_value, observed: observed}
  return {
    ok: false,
    status: "budget_exhausted",
    error: detail,
  }
}
// -------------------------------------------------------------------------------------------------
// with_budget
// -------------------------------------------------------------------------------------------------
/**
 * with_budget(next, opts) -> caller
 *
 * Tracks per-caller-instance usage across calls and refuses further calls
 * once a limit has been reached.
 *
 * IMPORTANT: Harn closures capture by VALUE, so we cannot persist a free-form
 * dict across calls. Counters are instead held as `atomic` handles, whose
 * underlying Arc<AtomicI64> is shared across closure invocations. This means
 * with_budget tracks INTEGER counters only.
 *
 * Options (all optional):
 *   - max_total_tokens: int (input + output)
 *   - max_input_tokens: int
 *   - max_output_tokens: int
 *   - max_calls: int
 *   - on_exceed: "throw" | "short_circuit" (default "short_circuit")
 *   - max_cost_usd: accepted for forward compatibility but NOT enforced —
 *     this implementation never reads it, so only token/call limits apply.
 *
 * Limits are checked BEFORE each call against the totals accumulated so
 * far, in the order max_calls, max_input_tokens, max_output_tokens,
 * max_total_tokens; a limit trips once the total is >= the limit. The call
 * that crosses a limit therefore still runs; the next one is refused.
 *
 * Accounting after each call: the call counter always increases by one.
 * Token counters increase only for ok envelopes whose value is a dict,
 * reading value.usage.input_tokens / output_tokens and falling back to
 * value.input_tokens / output_tokens, then 0.
 *
 * On exceed:
 *   - "short_circuit": returns {ok: false, status: "budget_exhausted", error}
 *   - "throw": throws (propagates to caller)
 */
pub fn with_budget(next, opts = nil) {
  let cfg = __opts_dict(opts)
  let max_total_tokens = cfg?.max_total_tokens
  let max_input_tokens = cfg?.max_input_tokens
  let max_output_tokens = cfg?.max_output_tokens
  let max_calls = cfg?.max_calls
  let on_exceed = cfg?.on_exceed ?? "short_circuit"
  let total_in = atomic(0)
  let total_out = atomic(0)
  let total_calls = atomic(0)
  return { call ->
    let calls_so_far = atomic_get(total_calls)
    let in_so_far = atomic_get(total_in)
    let out_so_far = atomic_get(total_out)
    let exhaustion = __budget_check("max_calls", max_calls, calls_so_far, on_exceed)
      ?? __budget_check("max_input_tokens", max_input_tokens, in_so_far, on_exceed)
      ?? __budget_check("max_output_tokens", max_output_tokens, out_so_far, on_exceed)
      ?? __budget_check("max_total_tokens", max_total_tokens, in_so_far + out_so_far, on_exceed)
    if exhaustion != nil {
      return exhaustion
    }
    let envelope = __safe_invoke(next, call)
    let _ = atomic_add(total_calls, 1)
    // Parenthesized so the guard does not depend on how `??` and `&&` rank
    // against each other: usage is read only from ok envelopes with a dict
    // value.
    if (envelope?.ok ?? false) && type_of(envelope?.value) == "dict" {
      let usage = __opts_dict(envelope.value?.usage)
      let in_tokens = to_int(usage?.input_tokens ?? envelope.value?.input_tokens ?? 0)
      let out_tokens = to_int(usage?.output_tokens ?? envelope.value?.output_tokens ?? 0)
      let _ = atomic_add(total_in, in_tokens)
      let _ = atomic_add(total_out, out_tokens)
    }
    return envelope
  }
}
// -------------------------------------------------------------------------------------------------
// with_cache
// -------------------------------------------------------------------------------------------------
/**
 * llm_cache_key returns the cache key with_cache uses for a direct call,
 * as computed by __llm_cache_key(prompt, system, options). A nil `options`
 * is treated as an empty dict.
 */
pub fn llm_cache_key(prompt, system = nil, options = nil) -> string {
  let effective_options = options ?? {}
  return __llm_cache_key(prompt, system, effective_options)
}
/**
 * with_cache has two calling forms, chosen by whether `first` is callable:
 *
 *   with_cache(next, opts?) -> caller
 *     Caller-wrapper form. Currently a transparent pass-through: the
 *     returned caller just invokes `next` safely; `opts` is ignored and
 *     nothing is cached.
 *
 *   with_cache(prompt, system?, options?) -> llm_call result
 *     Direct form. Unless the skip rule applies (options.skip_when, or by
 *     default the presence of options.tools), looks the call up under
 *     llm_cache_key(...) via cache_get and returns the stored value on a
 *     hit. On a miss, or when skipping, it calls llm_call; a miss also
 *     stores the result with cache_put. Errors from llm_call propagate.
 */
pub fn with_cache(first, second = nil, third = nil) {
  if __llm_handler_is_callable(first) {
    let wrapped = first
    return { call ->
      return __safe_invoke(wrapped, call)
    }
  }
  let options = third ?? {}
  if !__llm_handler_should_skip_cache(first, second, options) {
    let key = llm_cache_key(first, second, options)
    let store_options = __llm_handler_cache_options(options)
    let lookup = cache_get(key, store_options)
    if lookup.hit {
      return lookup.value
    }
    let fresh = llm_call(first, second, options)
    cache_put(key, fresh, store_options)
    return fresh
  }
  return llm_call(first, second, options)
}
// -------------------------------------------------------------------------------------------------
// with_circuit_breaker
// -------------------------------------------------------------------------------------------------
/**
 * with_circuit_breaker(handler, options) -> caller
 *
 * Wraps an LLM call handler with circuit-breaker protection.
 *
 * Unless `options.name` is given, every invocation uses a circuit named
 * after the call's `(opts.provider, opts.model)` pair, so one failing
 * upstream cannot poison other models routed through the same wrapper.
 * Pass `name` to share a single circuit across all calls on purpose.
 *
 * Options: name, threshold (default 5), reset_ms (default 30000).
 *
 * Behavior per call:
 *   - circuit open: THROWS a "circuit_open" error dict; handler is not run
 *   - handler throws: records a failure and rethrows the same error
 *   - handler returns a dict with an `ok` key that is falsy: records a
 *     failure and returns the result
 *   - anything else: records a success and returns the result
 *
 * Throws immediately at wrap time if `handler` is not callable.
 */
pub fn with_circuit_breaker(handler, options = nil) {
  if !__llm_handler_is_callable(handler) {
    throw "with_circuit_breaker: handler must be callable"
  }
  let cfg = options ?? {}
  let failure_threshold = cfg?.threshold ?? 5
  let cooldown_ms = cfg?.reset_ms ?? 30000
  return { call ->
    let circuit = cfg?.name ?? __circuit_name_for(call)
    __ensure_llm_handler_circuit(circuit, failure_threshold, cooldown_ms)
    if circuit_check(circuit) == "open" {
      throw __circuit_open_error(call, circuit)
    }
    let attempt = try {
      handler(call)
    }
    if is_err(attempt) {
      circuit_record_failure(circuit)
      throw unwrap_err(attempt)
    }
    let result = unwrap(attempt)
    // Only an explicit envelope with a falsy `ok` counts as a failure;
    // results without an `ok` key are treated as successes.
    let is_failed_envelope = type_of(result) == "dict" && contains(result.keys(), "ok") && !(result?.ok ?? false)
    if is_failed_envelope {
      circuit_record_failure(circuit)
    } else {
      circuit_record_success(circuit)
    }
    return result
  }
}
// -------------------------------------------------------------------------------------------------
// compose
// -------------------------------------------------------------------------------------------------
/**
 * compose(wrappers) -> fn(base) -> caller
 *
 * Stacks a LIST of wrappers around a base caller. (Harn does not yet support
 * user-defined variadic functions, hence the list.) Each wrapper has the
 * shape `fn(next) -> caller`.
 *
 * The last wrapper in the list is applied first, so
 *
 *   compose([a, b, c])(base) == a(b(c(base)))
 *
 * i.e. the leftmost wrapper ends up outermost. A non-list `wrappers` is
 * treated as an empty list, in which case `base` is returned unchanged.
 */
pub fn compose(wrappers) {
  let stack = if type_of(wrappers) == "list" {
    wrappers
  } else {
    []
  }
  return { base ->
    var wrapped = base
    // Walk from the end of the list towards the front.
    var remaining = len(stack)
    while remaining > 0 {
      remaining = remaining - 1
      let wrapper = stack[remaining]
      wrapped = wrapper(wrapped)
    }
    return wrapped
  }
}