#![allow(clippy::collapsible_if)]
use async_trait::async_trait;
use super::types::{HiddenStateConfig, ModelHiddenStates, ModelKVCache};
use crate::error::HiddenStateError;
/// Asynchronous source of model hidden states for input text.
///
/// Implementors must be `Send + Sync`. The two extraction methods are required, as are the
/// four descriptive accessors; the `supports_*` capability flags have default
/// implementations that can be overridden.
#[async_trait]
pub trait HiddenStateProvider: Send + Sync {
    /// Produces the [`ModelHiddenStates`] for `text`.
    ///
    /// # Errors
    ///
    /// Returns a [`HiddenStateError`] when the implementation cannot produce the states.
    async fn extract_hidden_states(
        &self,
        text: &str,
    ) -> Result<ModelHiddenStates, HiddenStateError>;

    /// Produces the [`ModelHiddenStates`] for `text` together with a [`ModelKVCache`].
    ///
    /// `past_kv` optionally supplies an existing KV cache; how it is consumed is up to the
    /// implementor (presumably as a prefix to continue from; confirm per implementation).
    ///
    /// # Errors
    ///
    /// Returns a [`HiddenStateError`] when the implementation cannot produce the result.
    async fn extract_with_kv_cache(
        &self,
        text: &str,
        past_kv: Option<&ModelKVCache>,
    ) -> Result<(ModelHiddenStates, ModelKVCache), HiddenStateError>;

    /// Borrows the configuration this provider operates with.
    fn model_config(&self) -> &HiddenStateConfig;

    /// Identifier of the underlying model.
    fn model_id(&self) -> &str;

    /// Number of layers reported by this provider.
    fn num_layers(&self) -> usize;

    /// Hidden dimension reported by this provider.
    fn hidden_dim(&self) -> usize;

    /// Whether attention weights are available.
    ///
    /// The default mirrors the `capture_attention_weights` flag of [`Self::model_config`].
    fn supports_attention_weights(&self) -> bool {
        let config = self.model_config();
        config.capture_attention_weights
    }

    /// Whether KV caching is supported.
    ///
    /// The default is `true`; override it for providers that cannot supply a KV cache.
    fn supports_kv_cache(&self) -> bool {
        true
    }
}
/// Policy for judging whether cached hidden states can serve a new input text.
///
/// Every method receives the cached states, the new text, and `cached_text`
/// (presumably the text the cached states were computed for; confirm against callers).
/// Implementors must be `Send + Sync`.
pub trait StateReuseStrategy: Send + Sync {
/// Returns `true` when `cached` may be reused for `new_text`.
fn can_reuse(&self, cached: &ModelHiddenStates, new_text: &str, cached_text: &str) -> bool;
/// Returns the position up to which `cached` can be reused.
///
/// NOTE(review): the unit (bytes, chars, or tokens) is not visible here; confirm with
/// the implementations.
fn reuse_point(&self, cached: &ModelHiddenStates, new_text: &str, cached_text: &str) -> usize;
/// Returns a score for how good the reuse would be.
///
/// NOTE(review): the value range is not established in this file (presumably 0.0..=1.0);
/// confirm with the implementations.
fn reuse_quality(&self, cached: &ModelHiddenStates, new_text: &str, cached_text: &str) -> f32;
/// Static, human-readable description of the strategy.
fn description(&self) -> &'static str;
}
/// Owned, type-erased [`StateReuseStrategy`] trait object.
pub type BoxedStateReuseStrategy = Box<dyn StateReuseStrategy>;
/// Cache-aware convenience methods layered on top of [`HiddenStateProvider`].
///
/// Both methods consult the supplied cache before delegating to the provider and store
/// freshly extracted results back into it.
#[async_trait]
pub trait HiddenStateProviderExt: HiddenStateProvider {
    /// Returns the hidden states for `text`, consulting `cache` first.
    ///
    /// On a cache hit the stored states are cloned and returned. On a miss the states are
    /// obtained from [`HiddenStateProvider::extract_hidden_states`], stored under `text`
    /// without a KV cache, and then returned.
    ///
    /// # Errors
    ///
    /// Propagates any [`HiddenStateError`] from the underlying extraction; nothing is
    /// written to the cache in that case.
    async fn extract_and_cache(
        &self,
        text: &str,
        cache: &mut super::cache::HiddenStateCache,
    ) -> Result<ModelHiddenStates, HiddenStateError> {
        // Finish the lookup (and its borrow of `cache`) before any await point.
        let cached_states = cache.get(text).map(|entry| entry.states.clone());
        if let Some(hit) = cached_states {
            return Ok(hit);
        }

        let fresh = self.extract_hidden_states(text).await?;
        cache.put(text.to_owned(), fresh.clone(), None);
        Ok(fresh)
    }

    /// Returns hidden states and a KV cache for `text`, using `cache` when possible.
    ///
    /// The cache is only consulted and only written when `past_kv` is `None`. A cached
    /// entry counts as a hit only if it also carries a KV cache; otherwise the result comes
    /// from [`HiddenStateProvider::extract_with_kv_cache`].
    ///
    /// # Errors
    ///
    /// Propagates any [`HiddenStateError`] from the underlying extraction; nothing is
    /// written to the cache in that case.
    async fn extract_with_caching(
        &self,
        text: &str,
        past_kv: Option<&ModelKVCache>,
        cache: &mut super::cache::HiddenStateCache,
    ) -> Result<(ModelHiddenStates, ModelKVCache), HiddenStateError> {
        // Entries are keyed by text alone, so they are only used when no prior KV is given.
        let cacheable = past_kv.is_none();

        if cacheable {
            let cached_pair = cache.get(text).and_then(|entry| {
                entry
                    .kv_cache
                    .as_ref()
                    .map(|kv| (entry.states.clone(), kv.clone()))
            });
            if let Some(hit) = cached_pair {
                return Ok(hit);
            }
        }

        let fresh = self.extract_with_kv_cache(text, past_kv).await?;
        if cacheable {
            cache.put(text.to_owned(), fresh.0.clone(), Some(fresh.1.clone()));
        }
        Ok(fresh)
    }
}
// Blanket impl: every `HiddenStateProvider` gets the extension methods with their default
// bodies. `T` carries the implicit `Sized` bound, so unsized types such as
// `dyn HiddenStateProvider` are not covered by this impl.
impl<T: HiddenStateProvider> HiddenStateProviderExt for T {}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::hidden_states::extractor::MockHiddenStateProvider;

    /// Model identifier shared by every mock provider in this module.
    const MODEL: &str = "test-model";

    /// Plain extraction reports the model id, layer count and hidden size given to the mock.
    #[tokio::test]
    async fn test_mock_provider_extract() {
        let mock = MockHiddenStateProvider::new(MODEL, 12, 768);

        let outcome = mock.extract_hidden_states("Hello, world!").await;
        let states = outcome.unwrap();

        assert_eq!(states.model_id, MODEL);
        assert_eq!(states.num_layers, 12);
        assert_eq!(states.hidden_dim, 768);
    }

    /// Extraction without a prior KV cache yields a KV cache with one entry per layer.
    #[tokio::test]
    async fn test_mock_provider_with_kv_cache() {
        let mock = MockHiddenStateProvider::new(MODEL, 6, 512);

        let outcome = mock.extract_with_kv_cache("Test input", None).await;
        let (states, kv_cache) = outcome.unwrap();

        assert_eq!(states.model_id, MODEL);
        assert_eq!(kv_cache.model_id, MODEL);
        assert_eq!(kv_cache.layers.len(), 6);
    }

    /// The synchronous accessors echo the constructor arguments of the mock.
    #[test]
    fn test_provider_trait_methods() {
        let mock = MockHiddenStateProvider::new(MODEL, 12, 768);

        assert_eq!(mock.model_id(), MODEL);
        assert_eq!(mock.num_layers(), 12);
        assert_eq!(mock.hidden_dim(), 768);
        assert!(mock.supports_kv_cache());
    }
}