use std::future::Future;
#[cfg(feature = "observe")]
use std::sync::atomic::{AtomicU64, Ordering};
#[cfg(feature = "observe")]
use std::time::{SystemTime, UNIX_EPOCH};
use rig_core::agent::{HookAction, PromptHook};
use rig_core::completion::{CompletionModel, CompletionResponse};
use rig_core::message::Message;
use rig_core::wasm_compat::WasmCompatSend;
use crate::{ModelDescriptor, ModelMetaProbe, ProbeError, ProviderId};
/// Prompt hook that logs model metadata and token usage for every completion
/// call through `tracing`.
///
/// Stores the provider id, the model name and an optional [`ModelDescriptor`],
/// which is consulted for the model's context window.
#[derive(Debug, Clone)]
pub struct MetaHook {
    /// Provider serving the model; logged under the `gen_ai_system` field.
    provider: ProviderId,
    /// Model name; logged as the request/response model.
    model: String,
    /// Descriptor supplied at construction (e.g. by a probe), if any; its
    /// `context_window` drives the context-utilisation figure.
    descriptor: Option<ModelDescriptor>,
    /// Conversation id stamped on the observe events emitted to `rig_tap`.
    #[cfg(feature = "observe")]
    observe_conversation_id: String,
}
impl MetaHook {
    /// Conversation id used for observe events until
    /// [`MetaHook::with_observe_conversation_id`] overrides it.
    #[cfg(feature = "observe")]
    const DEFAULT_OBSERVE_CONVERSATION_ID: &'static str = "default";

    /// Builds a hook by asking `probe` to describe `model`.
    ///
    /// If the probe yields no descriptor (for instance for a model it does not
    /// know), the hook is still built, just without a descriptor.
    ///
    /// # Errors
    ///
    /// Propagates any [`ProbeError`] returned by the probe's `describe` call.
    pub async fn resolve<P>(
        probe: &P,
        provider: impl Into<ProviderId>,
        model: impl Into<String>,
    ) -> Result<Self, ProbeError>
    where
        P: ModelMetaProbe + ?Sized,
    {
        let model = model.into();
        let descriptor = probe.describe(&model).await?;
        Ok(Self::from_descriptor(provider, model, descriptor))
    }

    /// Builds a hook with no descriptor, so no context window is reported.
    pub fn unresolved(provider: impl Into<ProviderId>, model: impl Into<String>) -> Self {
        Self::from_descriptor(provider, model, None)
    }

    /// Builds a hook from an already-known (possibly absent) descriptor.
    ///
    /// This is the single place where a `MetaHook` is assembled; the other
    /// constructors delegate here so field defaults stay in one spot.
    pub fn from_descriptor(
        provider: impl Into<ProviderId>,
        model: impl Into<String>,
        descriptor: Option<ModelDescriptor>,
    ) -> Self {
        Self {
            provider: provider.into(),
            model: model.into(),
            descriptor,
            #[cfg(feature = "observe")]
            observe_conversation_id: Self::DEFAULT_OBSERVE_CONVERSATION_ID.into(),
        }
    }

    /// Sets the conversation id attached to observe events emitted by this hook.
    #[cfg(feature = "observe")]
    #[must_use]
    pub fn with_observe_conversation_id(mut self, conversation_id: impl Into<String>) -> Self {
        self.observe_conversation_id = conversation_id.into();
        self
    }

    /// Conversation id attached to observe events (`"default"` unless overridden).
    #[cfg(feature = "observe")]
    pub fn observe_conversation_id(&self) -> &str {
        &self.observe_conversation_id
    }

    /// Descriptor held by this hook, if one was resolved or supplied.
    pub fn descriptor(&self) -> Option<&ModelDescriptor> {
        self.descriptor.as_ref()
    }

    /// Provider the model is served by.
    pub fn provider(&self) -> &ProviderId {
        &self.provider
    }

    /// Model name this hook reports.
    pub fn model(&self) -> &str {
        &self.model
    }

    /// Percentage of the context window consumed by `input_tokens`.
    ///
    /// Returns `None` when there is no descriptor, the descriptor has no
    /// context window, or the window is zero. The result is not clamped, so it
    /// can exceed 100.
    fn context_used_pct(&self, input_tokens: u64) -> Option<f64> {
        let window = self.descriptor.as_ref()?.context_window?;
        // A zero-sized window would divide by zero; treat it as unknown instead.
        (window != 0).then(|| input_tokens as f64 / window as f64 * 100.0)
    }
}
impl<M> PromptHook<M> for MetaHook
where
    M: CompletionModel,
{
    /// Records a `"completion call start"` info event (target
    /// `rig_model_meta::hook`) before each completion request.
    ///
    /// With the `observe` feature, also emits a `prompt.started` observe event
    /// whose `messages_in` counts the prior history plus the current prompt.
    /// The returned future resolves to `HookAction::cont()`.
    fn on_completion_call(
        &self,
        _prompt: &Message,
        _history: &[Message],
    ) -> impl Future<Output = HookAction> + WasmCompatSend {
        let context_window = self
            .descriptor
            .as_ref()
            .and_then(|descriptor| descriptor.context_window);

        tracing::info!(
            target: "rig_model_meta::hook",
            gen_ai_system = %self.provider,
            gen_ai_request_model = %self.model,
            gen_ai_model_context_window = context_window,
            "completion call start",
        );

        #[cfg(feature = "observe")]
        {
            // History entries plus the prompt being sent now.
            let messages_in = _history.len().saturating_add(1);
            emit_observe_prompt_started(&self.observe_conversation_id, &self.model, messages_in);
        }

        async { HookAction::cont() }
    }

    /// Records a `"completion call complete"` info event carrying token usage
    /// and, when the context window is known, the share of it that the input
    /// consumed.
    ///
    /// With the `observe` feature, also emits a `prompt.completed` observe
    /// event; zero token counts are left out of that event. The returned
    /// future resolves to `HookAction::cont()`.
    fn on_completion_response(
        &self,
        _prompt: &Message,
        response: &CompletionResponse<M::Response>,
    ) -> impl Future<Output = HookAction> + WasmCompatSend {
        let input_tokens = response.usage.input_tokens;
        let output_tokens = response.usage.output_tokens;
        let total_tokens = response.usage.total_tokens;
        let context_window = self
            .descriptor
            .as_ref()
            .and_then(|descriptor| descriptor.context_window);
        let used_pct = self.context_used_pct(input_tokens);

        // NOTE(review): the response model is reported as the configured model
        // name; the provider's own response model is not read here.
        tracing::info!(
            target: "rig_model_meta::hook",
            gen_ai_system = %self.provider,
            gen_ai_response_model = %self.model,
            gen_ai_usage_input_tokens = input_tokens,
            gen_ai_usage_output_tokens = output_tokens,
            gen_ai_usage_total_tokens = total_tokens,
            gen_ai_model_context_window = context_window,
            gen_ai_usage_context_used_pct = used_pct,
            "completion call complete",
        );

        #[cfg(feature = "observe")]
        emit_observe_prompt_completed(
            &self.observe_conversation_id,
            &self.model,
            positive(input_tokens),
            positive(output_tokens),
        );

        async { HookAction::cont() }
    }
}
/// Process-wide sequence number stamped as `tick` on each observe envelope.
/// Starts at 1; `Relaxed` suffices because only uniqueness/ordering of the
/// returned values matters, not synchronisation with other memory.
#[cfg(feature = "observe")]
static OBSERVE_TICK: AtomicU64 = AtomicU64::new(1);
/// Emits a `prompt.started` observe event carrying the model name and the
/// number of messages sent with the request.
#[cfg(feature = "observe")]
fn emit_observe_prompt_started(conversation_id: &str, model: &str, messages_in: usize) {
    let mut payload = observe_envelope(conversation_id, "prompt.started");
    // Fields are appended in the same order the event schema lists them.
    payload.extend([
        ("model".to_owned(), serde_json::Value::from(model)),
        ("messages_in".to_owned(), serde_json::json!(messages_in)),
    ]);
    emit_observe_event(payload);
}
/// Emits a `prompt.completed` observe event with the model name and whichever
/// token counts are present; `None` counts are omitted from the payload.
#[cfg(feature = "observe")]
fn emit_observe_prompt_completed(
    conversation_id: &str,
    model: &str,
    tokens_in: Option<u64>,
    tokens_out: Option<u64>,
) {
    let mut payload = observe_envelope(conversation_id, "prompt.completed");
    payload.insert("model".to_owned(), serde_json::Value::from(model));

    let token_counts = [("tokens_in", tokens_in), ("tokens_out", tokens_out)];
    for (field, count) in token_counts {
        if let Some(count) = count {
            payload.insert(field.to_owned(), serde_json::json!(count));
        }
    }

    emit_observe_event(payload);
}
/// Builds the common observe-event header: schema `version`, wall-clock
/// `occurred_at_millis`, the next `tick` from `OBSERVE_TICK`, the
/// `conversation_id` and the event `kind`.
#[cfg(feature = "observe")]
fn observe_envelope(
    conversation_id: &str,
    kind: &str,
) -> serde_json::Map<String, serde_json::Value> {
    // Timestamp first, then tick — same evaluation order as the field order.
    let occurred_at_millis = now_millis();
    let tick = OBSERVE_TICK.fetch_add(1, Ordering::Relaxed);

    let mut envelope = serde_json::Map::new();
    envelope.extend([
        ("version".to_owned(), serde_json::json!(1)),
        ("occurred_at_millis".to_owned(), serde_json::json!(occurred_at_millis)),
        ("tick".to_owned(), serde_json::json!(tick)),
        ("conversation_id".to_owned(), serde_json::Value::from(conversation_id)),
        ("kind".to_owned(), serde_json::Value::from(kind)),
    ]);
    envelope
}
/// Serialises `event` to JSON and logs it as the `event` field of an info
/// record on the `rig_tap` target. If serialisation fails, the event is
/// dropped without logging anything.
#[cfg(feature = "observe")]
fn emit_observe_event(event: serde_json::Map<String, serde_json::Value>) {
    let Ok(serialized) = serde_json::to_string(&event) else {
        return;
    };
    tracing::info!(target: "rig_tap", event = %serialized);
}
/// Milliseconds since the Unix epoch.
///
/// A clock set before the epoch yields `0`; a value too large for `u64`
/// saturates to `u64::MAX`.
#[cfg(feature = "observe")]
fn now_millis() -> u64 {
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| {
            u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX)
        })
}
/// Maps a zero count to `None` so it can be omitted from observe payloads;
/// any non-zero count is returned unchanged in `Some`.
#[cfg(feature = "observe")]
fn positive(value: u64) -> Option<u64> {
    (value != 0).then_some(value)
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::panic, clippy::indexing_slicing)]
mod tests {
    use super::*;
    use crate::{ModelDescriptor, StubProbe};

    /// Descriptor for the stubbed `qwen3.5:9b` model with a 131 072-token window.
    fn descriptor() -> ModelDescriptor {
        ModelDescriptor::builder("ollama", "qwen3.5:9b")
            .context_window(131_072)
            .build()
    }

    #[tokio::test]
    async fn resolve_caches_descriptor_from_probe() {
        let probe = StubProbe::new([("qwen3.5:9b", descriptor())]);
        let hook = MetaHook::resolve(&probe, "ollama", "qwen3.5:9b")
            .await
            .unwrap();
        assert_eq!(hook.provider().as_str(), "ollama");
        assert_eq!(hook.model(), "qwen3.5:9b");
        assert_eq!(hook.descriptor().unwrap().context_window, Some(131_072));
    }

    #[tokio::test]
    async fn resolve_tolerates_unknown_model() {
        let probe = StubProbe::default();
        let hook = MetaHook::resolve(&probe, "ollama", "unknown")
            .await
            .unwrap();
        assert!(hook.descriptor().is_none());
    }

    #[test]
    fn context_used_pct_computes_against_window() {
        let hook = MetaHook::from_descriptor("ollama", "qwen3.5:9b", Some(descriptor()));
        let pct = hook.context_used_pct(65_536).unwrap();
        assert!((pct - 50.0).abs() < 1e-9);
    }

    #[test]
    fn context_used_pct_is_none_when_window_unknown() {
        let hook = MetaHook::unresolved("openai", "gpt-4o");
        assert!(hook.context_used_pct(1000).is_none());
    }

    #[test]
    fn context_used_pct_handles_zero_window() {
        let desc = ModelDescriptor::builder("p", "m").context_window(0).build();
        let hook = MetaHook::from_descriptor("p", "m", Some(desc));
        assert!(hook.context_used_pct(100).is_none());
    }

    #[cfg(feature = "observe")]
    #[test]
    fn observe_envelope_matches_rig_tap_shape() {
        let event = observe_envelope("thread-1", "prompt.started");
        assert_eq!(event.get("version").unwrap(), &serde_json::json!(1));
        assert_eq!(
            event.get("conversation_id").unwrap(),
            &serde_json::json!("thread-1")
        );
        assert_eq!(
            event.get("kind").unwrap(),
            &serde_json::json!("prompt.started")
        );
        assert!(event.get("occurred_at_millis").is_some());
        assert!(event.get("tick").is_some());
    }

    #[cfg(feature = "observe")]
    #[test]
    fn observe_conversation_id_is_configurable() {
        let hook =
            MetaHook::unresolved("ollama", "qwen3.5:9b").with_observe_conversation_id("thread-42");
        assert_eq!(hook.observe_conversation_id(), "thread-42");
    }

    /// Every constructor must start from the same default conversation id.
    #[cfg(feature = "observe")]
    #[tokio::test]
    async fn observe_conversation_id_defaults_for_every_constructor() {
        let probe = StubProbe::new([("qwen3.5:9b", descriptor())]);
        let resolved = MetaHook::resolve(&probe, "ollama", "qwen3.5:9b")
            .await
            .unwrap();
        let unresolved = MetaHook::unresolved("ollama", "qwen3.5:9b");
        let explicit = MetaHook::from_descriptor("ollama", "qwen3.5:9b", None);
        for hook in [&resolved, &unresolved, &explicit] {
            assert_eq!(hook.observe_conversation_id(), "default");
        }
    }

    /// Successive envelopes built on one thread must carry increasing ticks,
    /// even while other tests advance the shared counter concurrently.
    #[cfg(feature = "observe")]
    #[test]
    fn observe_envelope_ticks_strictly_increase() {
        let tick = |event: &serde_json::Map<String, serde_json::Value>| {
            event
                .get("tick")
                .and_then(serde_json::Value::as_u64)
                .unwrap()
        };
        let first = tick(&observe_envelope("thread-1", "prompt.started"));
        let second = tick(&observe_envelope("thread-1", "prompt.completed"));
        assert!(second > first);
    }

    #[cfg(feature = "observe")]
    #[test]
    fn positive_treats_zero_as_absent() {
        assert_eq!(positive(0), None);
        assert_eq!(positive(42), Some(42));
    }
}