use async_trait::async_trait;
use super::{Session, SessionError, SessionId, SessionService};
use crate::events::Event;
/// Connection settings for the Vertex AI session backend.
///
/// Construct with [`VertexAiSessionConfig::new`] and optionally chain
/// [`VertexAiSessionConfig::ttl_seconds`].
#[derive(Debug, Clone)]
pub struct VertexAiSessionConfig {
    /// Google Cloud project, substituted into the `projects/{project}` path segment.
    pub project: String,
    /// Region (e.g. `us-central1`); used both as the API host prefix and in the
    /// `locations/{location}` path segment.
    pub location: String,
    /// Optional session time-to-live in seconds; `None` unless set via
    /// [`VertexAiSessionConfig::ttl_seconds`].
    pub ttl_seconds: Option<u64>,
}

impl VertexAiSessionConfig {
    /// Creates a config for `project` in `location` with no TTL.
    #[must_use]
    pub fn new(project: impl Into<String>, location: impl Into<String>) -> Self {
        Self {
            project: project.into(),
            location: location.into(),
            ttl_seconds: None,
        }
    }

    /// Sets the session TTL (in seconds) and returns the updated config.
    ///
    /// This consumes `self`; ignoring the return value discards the TTL, hence
    /// `#[must_use]`. Calling it again overwrites the previous value.
    #[must_use]
    pub fn ttl_seconds(mut self, ttl: u64) -> Self {
        self.ttl_seconds = Some(ttl);
        self
    }

    /// Regional API root:
    /// `https://{location}-aiplatform.googleapis.com/v1beta1/projects/{project}/locations/{location}`.
    fn base_url(&self) -> String {
        format!(
            "https://{location}-aiplatform.googleapis.com/v1beta1/projects/{project}/locations/{location}",
            project = self.project,
            location = self.location,
        )
    }

    /// Collection URL for the sessions of reasoning engine `engine_id`.
    fn sessions_url(&self, engine_id: &str) -> String {
        format!("{}/reasoningEngines/{}/sessions", self.base_url(), engine_id)
    }

    /// URL of one session. `session_id` is inserted verbatim (no percent-encoding).
    fn session_url(&self, engine_id: &str, session_id: &str) -> String {
        format!("{}/{}", self.sessions_url(engine_id), session_id)
    }

    /// Events sub-collection URL of one session.
    fn events_url(&self, engine_id: &str, session_id: &str) -> String {
        format!("{}/events", self.session_url(engine_id, session_id))
    }
}
/// Session service targeting Vertex AI, configured by a [`VertexAiSessionConfig`].
#[derive(Debug, Clone)]
pub struct VertexAiSessionService {
    config: VertexAiSessionConfig,
}

impl VertexAiSessionService {
    /// Wraps `config` in a new service.
    #[must_use]
    pub fn new(config: VertexAiSessionConfig) -> Self {
        Self { config }
    }

    /// Google Cloud project taken from the config.
    pub fn project(&self) -> &str {
        &self.config.project
    }

    /// Region taken from the config.
    pub fn location(&self) -> &str {
        &self.config.location
    }

    /// Configured session TTL in seconds, if any.
    pub fn ttl_seconds(&self) -> Option<u64> {
        self.config.ttl_seconds
    }
}
/// [`SessionService`] backed by the Vertex AI sessions REST API.
///
/// Every method currently computes its target URL and then stops at `todo!()`,
/// so calling any of them panics; no HTTP request is issued yet.
///
/// NOTE(review): `create_session` and `list_sessions` pass `app_name` as the
/// reasoning-engine id, while the per-session methods hard-code the engine id
/// `"default"`. A session created for any other app name would be addressed
/// under a different engine. Confirm the intended engine id before wiring up
/// the HTTP calls.
#[async_trait]
impl SessionService for VertexAiSessionService {
    /// Intended to create a session for `user_id` under engine `app_name`
    /// (not implemented: panics).
    async fn create_session(
        &self,
        app_name: &str,
        user_id: &str,
    ) -> Result<Session, SessionError> {
        let endpoint = self.config.sessions_url(app_name);
        // Request-body fragment for the configured TTL, rendered as a
        // seconds-suffixed duration (e.g. `"ttl": "3600s"`). Not sent anywhere yet.
        let _ttl_field = self.config.ttl_seconds.map(|secs| format!("\"ttl\": \"{secs}s\""));
        todo!("POST to {endpoint} to create Vertex AI session for user={user_id}")
    }

    /// Intended to fetch one session by id (not implemented: panics).
    async fn get_session(&self, id: &SessionId) -> Result<Option<Session>, SessionError> {
        let endpoint = self.config.session_url("default", id.as_str());
        todo!("GET {endpoint} to fetch Vertex AI session")
    }

    /// Intended to list `user_id`'s sessions under engine `app_name`
    /// (not implemented: panics).
    async fn list_sessions(
        &self,
        app_name: &str,
        user_id: &str,
    ) -> Result<Vec<Session>, SessionError> {
        let endpoint = self.config.sessions_url(app_name);
        todo!("GET {endpoint} to list Vertex AI sessions for user={user_id}")
    }

    /// Intended to delete one session by id (not implemented: panics).
    async fn delete_session(&self, id: &SessionId) -> Result<(), SessionError> {
        let endpoint = self.config.session_url("default", id.as_str());
        todo!("DELETE {endpoint} to remove Vertex AI session")
    }

    /// Intended to append `event` to a session (not implemented: panics).
    ///
    /// # Errors
    ///
    /// Returns `SessionError::Storage` if `event` cannot be serialized to JSON.
    /// This check runs before the `todo!()` is reached.
    ///
    /// NOTE(review): the published v1beta1 REST reference appears to expose
    /// appending as `POST {session}:appendEvent` rather than a POST to
    /// `{session}/events`. Confirm before implementing.
    async fn append_event(&self, id: &SessionId, event: Event) -> Result<(), SessionError> {
        let endpoint = self.config.events_url("default", id.as_str());
        let _payload = serde_json::to_value(&event)
            .map_err(|err| SessionError::Storage(err.to_string()))?;
        todo!("POST to {endpoint} to append event to Vertex AI session")
    }

    /// Intended to fetch all events of a session (not implemented: panics).
    async fn get_events(&self, id: &SessionId) -> Result<Vec<Event>, SessionError> {
        let endpoint = self.config.events_url("default", id.as_str());
        todo!("GET {endpoint} to fetch events for Vertex AI session")
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn config_new() {
        let config = VertexAiSessionConfig::new("my-project", "us-central1");
        assert_eq!(config.project, "my-project");
        assert_eq!(config.location, "us-central1");
        assert!(config.ttl_seconds.is_none());
    }

    #[test]
    fn config_with_ttl() {
        let config = VertexAiSessionConfig::new("proj", "us-east1").ttl_seconds(3600);
        assert_eq!(config.ttl_seconds, Some(3600));
    }

    #[test]
    fn config_ttl_last_call_wins() {
        let config = VertexAiSessionConfig::new("proj", "us-east1")
            .ttl_seconds(60)
            .ttl_seconds(120);
        assert_eq!(config.ttl_seconds, Some(120));
    }

    #[test]
    fn url_construction() {
        let config = VertexAiSessionConfig::new("my-project", "us-central1");
        let base = config.base_url();
        assert_eq!(
            base,
            "https://us-central1-aiplatform.googleapis.com/v1beta1/projects/my-project/locations/us-central1"
        );
        // Exact comparisons: a `contains` check would also accept URLs with
        // missing, duplicated, or misordered path segments.
        assert_eq!(
            config.sessions_url("engine-1"),
            format!("{base}/reasoningEngines/engine-1/sessions")
        );
        assert_eq!(
            config.session_url("engine-1", "sess-1"),
            format!("{base}/reasoningEngines/engine-1/sessions/sess-1")
        );
        assert_eq!(
            config.events_url("engine-1", "sess-1"),
            format!("{base}/reasoningEngines/engine-1/sessions/sess-1/events")
        );
    }

    #[test]
    fn service_accessors() {
        let svc = VertexAiSessionService::new(
            VertexAiSessionConfig::new("proj", "us-west1").ttl_seconds(7200),
        );
        assert_eq!(svc.project(), "proj");
        assert_eq!(svc.location(), "us-west1");
        assert_eq!(svc.ttl_seconds(), Some(7200));
    }

    #[test]
    fn service_without_ttl() {
        let svc = VertexAiSessionService::new(VertexAiSessionConfig::new("proj", "us-west1"));
        assert_eq!(svc.ttl_seconds(), None);
    }
}