use async_trait::async_trait;
use crate::ai::provider::{CompletionRequest, CompletionResponse, StreamChunk};
use crate::errors::{NoosError, NoosResult};
use crate::types::intervention::{
CognitiveState, InterventionDepth, LogitBias, SamplingOverride,
};
/// A model backend that can run completions while applying a caller-supplied
/// [`SamplingOverride`].
///
/// Implementors must be `Send + Sync`. Each provider reports which
/// [`InterventionDepth`] it supports via [`InferenceProvider::intervention_depth`].
#[async_trait]
pub trait InferenceProvider: Send + Sync {
    /// Reports the intervention depth this provider supports.
    ///
    /// The default [`InferenceProvider::get_next_token_logits`] includes this
    /// value in its error message.
    fn intervention_depth(&self) -> InterventionDepth;

    /// Runs a (non-streaming) completion for `request` with `sampling` applied.
    ///
    /// # Errors
    /// Implementation-defined; failures are reported as a [`NoosError`].
    async fn complete_with_override(
        &self,
        request: CompletionRequest,
        sampling: SamplingOverride,
    ) -> NoosResult<CompletionResponse>;

    /// Streams a completion for `request` with `sampling` applied.
    ///
    /// Output is presumably delivered as [`StreamChunk`]s through `sender`
    /// (NOTE(review): confirm against implementors).
    ///
    /// # Errors
    /// Implementation-defined; failures are reported as a [`NoosError`].
    async fn stream_with_override(
        &self,
        request: CompletionRequest,
        sampling: SamplingOverride,
        sender: tokio::sync::mpsc::Sender<StreamChunk>,
    ) -> NoosResult<()>;

    /// Returns next-token logits for `request`.
    ///
    /// The default implementation ignores the request. It always returns
    /// [`NoosError::UnsupportedIntervention`], stating that
    /// [`InterventionDepth::LogitAccess`] is required and reporting the depth
    /// from [`InferenceProvider::intervention_depth`]. Providers with logit
    /// access are expected to override it.
    ///
    /// # Errors
    /// Always errors unless overridden.
    async fn get_next_token_logits(
        &self,
        _request: CompletionRequest,
    ) -> NoosResult<Vec<f32>> {
        // Evaluate the required depth first, then query the provider,
        // mirroring the argument order of the message below.
        let required = InterventionDepth::LogitAccess;
        let supported = self.intervention_depth();
        let message = format!(
            "get_next_token_logits requires {:?}, model supports {:?}",
            required, supported,
        );
        Err(NoosError::UnsupportedIntervention(message))
    }
}
/// Produces token-level logit biases from a [`CognitiveState`].
///
/// Returning an empty `Vec` is valid and means "apply no biases".
pub trait LogitIntervenor {
    /// Computes the logit biases to apply for the given `state`.
    ///
    /// The result is a pure computed value. Calling this and discarding the
    /// result is almost certainly a mistake, hence `#[must_use]`.
    #[must_use]
    fn compute_logit_biases(&self, state: &CognitiveState) -> Vec<LogitBias>;
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An intervenor that never biases anything may return an empty `Vec`.
    #[test]
    fn logit_intervenor_can_return_empty() {
        struct NoOpIntervenor;

        impl LogitIntervenor for NoOpIntervenor {
            fn compute_logit_biases(&self, _state: &CognitiveState) -> Vec<LogitBias> {
                Vec::new()
            }
        }

        let intervenor = NoOpIntervenor;
        let state = CognitiveState::default();
        let biases = intervenor.compute_logit_biases(&state);
        assert!(biases.is_empty());
    }

    /// A state-dependent intervenor returns biases only when its condition holds,
    /// and every field of the produced bias is preserved.
    #[test]
    fn logit_intervenor_returns_biases() {
        struct MockIntervenor;

        impl LogitIntervenor for MockIntervenor {
            fn compute_logit_biases(&self, state: &CognitiveState) -> Vec<LogitBias> {
                if state.arousal > 0.6 {
                    vec![LogitBias {
                        token_id: 100,
                        bias: -2.0,
                        source: "mock".into(),
                    }]
                } else {
                    Vec::new()
                }
            }
        }

        let intervenor = MockIntervenor;

        // NOTE(review): relies on CognitiveState::default() having arousal <= 0.6 — confirm.
        let calm = CognitiveState::default();
        assert!(intervenor.compute_logit_biases(&calm).is_empty());

        let aroused = CognitiveState {
            arousal: 0.8,
            ..CognitiveState::default()
        };
        let biases = intervenor.compute_logit_biases(&aroused);
        assert_eq!(biases.len(), 1);

        // Check every field, not just `bias`, so a regression in any of them is caught.
        let produced = &biases[0];
        assert_eq!(produced.token_id, 100);
        assert_eq!(produced.bias, -2.0);
        // Deref to `&str` so the comparison works for String, Cow<str>, Box<str>, etc.
        assert_eq!(&*produced.source, "mock");
    }
}