use std::sync::Arc;
use arc_swap::ArcSwap;
use async_trait::async_trait;
use rakka_core::actor::{Actor, Context};
use tokio::sync::oneshot;
use inference_core::deployment::Timeouts;
use inference_core::error::{InferenceError, InferenceResult};
use inference_core::runner::SessionRebuildCause;
use inference_core::SecretString;
use crate::http::{build_client, HttpClient};
/// Inputs needed to build (and later rebuild) a remote session snapshot.
///
/// The actor keeps a copy of this configuration so that every rebuild uses
/// the same settings as the initial bootstrap.
#[derive(Clone)]
pub struct SessionConfig {
/// User-agent value handed to `build_client` when the HTTP client is constructed.
pub user_agent: String,
/// Timeout settings handed to `build_client` when the HTTP client is constructed.
pub timeouts: Timeouts,
/// Source of the credential; its `token()` is awaited on every snapshot build.
pub credential: Arc<dyn CredentialProvider>,
}
/// Asynchronous source of the secret credential stored in a session snapshot.
///
/// Implementations must be `Send + Sync` because the provider is shared as
/// `Arc<dyn CredentialProvider>` inside [`SessionConfig`].
#[async_trait]
pub trait CredentialProvider: Send + Sync {
/// Produces the credential to use for the session.
///
/// # Errors
///
/// Returns an error if the credential cannot be obtained; the snapshot
/// build that requested it fails with that error.
async fn token(&self) -> InferenceResult<SecretString>;
}
/// Credential provider backed by a single, fixed secret supplied at construction.
pub struct StaticApiKey(pub SecretString);
#[async_trait]
impl CredentialProvider for StaticApiKey {
    /// Returns a fresh `SecretString` holding a copy of the wrapped key.
    ///
    /// This implementation always succeeds.
    async fn token(&self) -> InferenceResult<SecretString> {
        use inference_core::ExposeSecret;

        // Copy the exposed value and immediately re-wrap it so the caller
        // receives its own independent secret.
        let exposed = self.0.expose_secret();
        let duplicate = exposed.to_string();
        Ok(SecretString::from(duplicate))
    }
}
/// One built session state: an HTTP client together with the credential
/// fetched in the same build.
///
/// Instances are published through an `ArcSwap`, so a rebuild replaces the
/// whole snapshot rather than mutating fields in place.
pub struct SessionSnapshot {
/// HTTP client produced by `build_client` from the session configuration.
pub client: HttpClient,
/// Credential returned by the configured `CredentialProvider`.
pub credential: SecretString,
}
/// Message asking the session actor to rebuild its snapshot.
pub struct SessionRebuildRequest {
/// Why the rebuild was requested; it is logged when the rebuild starts.
pub cause: SessionRebuildCause,
/// One-shot channel on which the outcome of the rebuild is sent back.
pub reply: oneshot::Sender<InferenceResult<()>>,
}
/// Actor that owns the remote session configuration and publishes the
/// current [`SessionSnapshot`] through a shared `ArcSwap`.
pub struct RemoteSessionActor {
// Retained so rebuilds use the same settings as the initial bootstrap.
config: SessionConfig,
// Shared handle to the current snapshot; clones are handed out by `snapshot()`.
snapshot: Arc<ArcSwap<SessionSnapshot>>,
}
impl RemoteSessionActor {
pub async fn bootstrap(config: SessionConfig) -> InferenceResult<Self> {
let snapshot = Self::build_snapshot(&config).await?;
Ok(Self {
config,
snapshot: Arc::new(ArcSwap::from_pointee(snapshot)),
})
}
pub fn snapshot(&self) -> Arc<ArcSwap<SessionSnapshot>> {
self.snapshot.clone()
}
async fn build_snapshot(config: &SessionConfig) -> InferenceResult<SessionSnapshot> {
let client = build_client(&config.timeouts, &config.user_agent)
.map_err(|e| InferenceError::Internal(format!("build http client: {e}")))?;
let credential = config.credential.token().await?;
Ok(SessionSnapshot { client, credential })
}
async fn rebuild(&mut self, cause: SessionRebuildCause) -> InferenceResult<()> {
tracing::info!(?cause, "rebuilding remote session");
let snap = Self::build_snapshot(&self.config).await?;
self.snapshot.store(Arc::new(snap));
Ok(())
}
}
#[async_trait]
impl Actor for RemoteSessionActor {
type Msg = SessionRebuildRequest;
async fn handle(&mut self, _ctx: &mut Context<Self>, msg: Self::Msg) {
let res = self.rebuild(msg.cause).await;
let _ = msg.reply.send(res);
}
}