use std::sync::Arc;
use std::collections::HashMap;
use chrono::{DateTime, Utc, Duration as ChronoDuration};
use serde::{Serialize, Deserialize};
use tokio::sync::RwLock;
use crate::{SaTokenError, SaTokenResult, SaTokenManager};
/// Hook invoked with the `login_id` being logged out on the client side.
///
/// Stored behind an `Arc` and required to be `Send + Sync`. The returned `bool`
/// is currently not inspected by `SsoClient::handle_logout`.
type LogoutCallback = Arc<dyn Fn(&str) -> bool + Sync + Send>;
/// One-time service ticket issued by the SSO server for a specific client service.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SsoTicket {
    /// Random identifier (UUID v4 rendered as a string).
    pub ticket_id: String,
    /// Service the ticket was issued for.
    pub service: String,
    /// User the ticket authenticates.
    pub login_id: String,
    /// Moment the ticket was issued.
    pub create_time: DateTime<Utc>,
    /// Moment after which the ticket is no longer accepted.
    pub expire_time: DateTime<Utc>,
    /// Set once the ticket has been consumed.
    pub used: bool,
}

impl SsoTicket {
    /// Issues a fresh, unused ticket for `login_id` at `service`.
    ///
    /// The ticket expires `timeout_seconds` seconds after it is created.
    pub fn new(login_id: String, service: String, timeout_seconds: i64) -> Self {
        let issued_at = Utc::now();
        let lifetime = ChronoDuration::seconds(timeout_seconds);
        let ticket_id = uuid::Uuid::new_v4().to_string();
        SsoTicket {
            ticket_id,
            service,
            login_id,
            create_time: issued_at,
            expire_time: issued_at + lifetime,
            used: false,
        }
    }

    /// Returns `true` once the current time is strictly past `expire_time`.
    pub fn is_expired(&self) -> bool {
        self.expire_time < Utc::now()
    }

    /// Returns `true` while the ticket is neither consumed nor expired.
    pub fn is_valid(&self) -> bool {
        !(self.used || self.is_expired())
    }
}
/// Server-side record of one user's SSO login and the client services attached to it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SsoSession {
    /// User owning the session.
    pub login_id: String,
    /// Distinct service identifiers, kept in first-registration order.
    pub clients: Vec<String>,
    /// Moment the session record was created.
    pub create_time: DateTime<Utc>,
    /// Refreshed on every `add_client` / `remove_client` call.
    pub last_active_time: DateTime<Utc>,
}

impl SsoSession {
    /// Starts a session for `login_id` with no clients.
    ///
    /// Both timestamps are set to the same current instant.
    pub fn new(login_id: String) -> Self {
        let started_at = Utc::now();
        SsoSession {
            login_id,
            clients: vec![],
            create_time: started_at,
            last_active_time: started_at,
        }
    }

    /// Registers `service` unless it is already present, then refreshes activity.
    pub fn add_client(&mut self, service: String) {
        let already_known = self.clients.iter().any(|known| *known == service);
        if !already_known {
            self.clients.push(service);
        }
        self.touch();
    }

    /// Drops every entry equal to `service`, then refreshes activity.
    pub fn remove_client(&mut self, service: &str) {
        self.clients.retain(|known| known.as_str() != service);
        self.touch();
    }

    /// Stamps `last_active_time` with the current time.
    fn touch(&mut self) {
        self.last_active_time = Utc::now();
    }
}
/// SSO authority: issues one-time service tickets and tracks, per user, which
/// client services have been granted a ticket.
pub struct SsoServer {
    manager: Arc<SaTokenManager>,
    /// Outstanding tickets keyed by `ticket_id`.
    tickets: Arc<RwLock<HashMap<String, SsoTicket>>>,
    /// SSO sessions keyed by `login_id`.
    sessions: Arc<RwLock<HashMap<String, SsoSession>>>,
    /// Ticket lifetime in seconds (passed to `SsoTicket::new`).
    ticket_timeout: i64,
}

impl SsoServer {
    /// Storage key whose presence marks `login_id` as logged in through SSO.
    fn sso_token_key(login_id: &str) -> String {
        format!("sa:login:token:{}:sso", login_id)
    }

    /// Creates a server with a 300-second ticket lifetime.
    pub fn new(manager: Arc<SaTokenManager>) -> Self {
        Self {
            manager,
            tickets: Arc::new(RwLock::new(HashMap::new())),
            sessions: Arc::new(RwLock::new(HashMap::new())),
            ticket_timeout: 300,
        }
    }

    /// Overrides the ticket lifetime in seconds.
    pub fn with_ticket_timeout(mut self, timeout: i64) -> Self {
        self.ticket_timeout = timeout;
        self
    }

    /// Returns `true` when `login_id` has an SSO session and the SSO token key is
    /// still present in storage. Any storage error counts as "not logged in".
    pub async fn is_logged_in(&self, login_id: &str) -> bool {
        // The read guard is a temporary of this statement, so it is released
        // before the storage lookup below is awaited.
        let has_session = self.sessions.read().await.contains_key(login_id);
        if !has_session {
            return false;
        }
        let key = Self::sso_token_key(login_id);
        matches!(self.manager.storage.get(&key).await, Ok(Some(_)))
    }

    /// Issues a ticket for `login_id` at `service` and records `service` on the
    /// user's session, creating the session if needed.
    ///
    /// # Errors
    /// Currently never fails; the `Result` is kept for API stability.
    pub async fn create_ticket(&self, login_id: String, service: String) -> SaTokenResult<SsoTicket> {
        let ticket = SsoTicket::new(login_id.clone(), service.clone(), self.ticket_timeout);

        // Release the tickets lock before taking the sessions lock so this method
        // never holds both at once (no lock-ordering hazard, less contention).
        self.tickets
            .write()
            .await
            .insert(ticket.ticket_id.clone(), ticket.clone());

        let mut sessions = self.sessions.write().await;
        sessions
            .entry(login_id.clone())
            .or_insert_with(|| SsoSession::new(login_id))
            .add_client(service);

        Ok(ticket)
    }

    /// Validates and consumes `ticket_id` for `service`, returning its `login_id`.
    ///
    /// A ticket gets exactly one validation attempt, successful or not (CAS
    /// protocol 3.0, section 3.1.1). A ticket presented for the wrong service is
    /// burned as well, so it cannot be replayed later.
    ///
    /// # Errors
    /// - `InvalidTicket` if the ticket is unknown.
    /// - `TicketExpired` if it was already used or has expired.
    /// - `ServiceMismatch` if it was issued for a different service.
    pub async fn validate_ticket(&self, ticket_id: &str, service: &str) -> SaTokenResult<String> {
        let mut tickets = self.tickets.write().await;
        let ticket = tickets
            .get_mut(ticket_id)
            .ok_or(SaTokenError::InvalidTicket)?;
        if !ticket.is_valid() {
            return Err(SaTokenError::TicketExpired);
        }
        // Consume before the service check: a mismatch must not leave the ticket usable.
        ticket.used = true;
        if ticket.service != service {
            return Err(SaTokenError::ServiceMismatch);
        }
        Ok(ticket.login_id.clone())
    }

    /// Logs `login_id` in through the manager under the `"sso"` device tag.
    /// Then issues a ticket for `service`, which also registers `service` on the session.
    ///
    /// # Errors
    /// Propagates any error from `login_with_options`.
    pub async fn login(&self, login_id: String, service: String) -> SaTokenResult<SsoTicket> {
        let _token = self
            .manager
            .login_with_options(
                &login_id,
                Some("sso".to_string()),
                None,
                Some(serde_json::json!({
                    "sso_mode": true,
                    "service": service.clone()
                })),
                None,
                None,
            )
            .await?;
        // `create_ticket` creates/updates the session entry, so no separate update here.
        self.create_ticket(login_id, service).await
    }

    /// Ends the SSO session for `login_id` and returns the services that were
    /// attached to it, so the caller can notify them.
    ///
    /// # Errors
    /// Propagates errors from `logout_by_login_id`. Deleting the SSO storage key
    /// is best-effort and its failure is ignored.
    pub async fn logout(&self, login_id: &str) -> SaTokenResult<Vec<String>> {
        let clients = self
            .sessions
            .write()
            .await
            .remove(login_id)
            .map(|s| s.clients)
            .unwrap_or_default();

        // Best-effort: the manager-level logout below is the authoritative step.
        let _ = self.manager.storage.delete(&Self::sso_token_key(login_id)).await;
        self.manager.logout_by_login_id(login_id).await?;
        Ok(clients)
    }

    /// Returns a snapshot of the session for `login_id`, if any.
    pub async fn get_session(&self, login_id: &str) -> Option<SsoSession> {
        self.sessions.read().await.get(login_id).cloned()
    }

    /// Returns `true` if an SSO session exists for `login_id` (storage is not consulted).
    pub async fn check_session(&self, login_id: &str) -> bool {
        self.sessions.read().await.contains_key(login_id)
    }

    /// Drops every ticket that is no longer valid, i.e. both expired and used tickets.
    pub async fn cleanup_expired_tickets(&self) {
        self.tickets.write().await.retain(|_, ticket| ticket.is_valid());
    }

    /// Returns the services registered on `login_id`'s session (empty if none).
    pub async fn get_active_clients(&self, login_id: &str) -> Vec<String> {
        self.sessions
            .read()
            .await
            .get(login_id)
            .map(|s| s.clients.clone())
            .unwrap_or_default()
    }
}
/// Service-provider side of SSO. It builds redirect URLs to the SSO server and
/// manages the local `"sso_client"` login for a user.
pub struct SsoClient {
    manager: Arc<SaTokenManager>,
    /// SSO server URL; used as the login endpoint, with `/logout` appended for logout.
    server_url: String,
    /// This service's URL, sent to the server as the `service` query parameter.
    service_url: String,
    /// Optional hook run by `handle_logout` before the local logout.
    logout_callback: Option<LogoutCallback>,
}

impl SsoClient {
    /// Creates a client with no logout callback.
    pub fn new(
        manager: Arc<SaTokenManager>,
        server_url: String,
        service_url: String,
    ) -> Self {
        Self {
            manager,
            server_url,
            service_url,
            logout_callback: None,
        }
    }

    /// Installs a callback invoked with the `login_id` in `handle_logout`.
    pub fn with_logout_callback<F>(mut self, callback: F) -> Self
    where
        F: Fn(&str) -> bool + Send + Sync + 'static,
    {
        self.logout_callback = Some(Arc::new(callback));
        self
    }

    /// Appends `service=<url-encoded service>` to `base`.
    ///
    /// Uses `&` instead of `?` when `base` already carries a query string.
    fn append_service_param(base: &str, service: &str) -> String {
        let separator = if base.contains('?') { '&' } else { '?' };
        format!("{}{}service={}", base, separator, urlencoding::encode(service))
    }

    /// URL that sends the browser to the SSO server's login, carrying this service.
    pub fn get_login_url(&self) -> String {
        Self::append_service_param(&self.server_url, &self.service_url)
    }

    /// URL of the SSO server's logout endpoint, carrying this service.
    ///
    /// `/logout` is inserted after the path, before any existing query string.
    /// Trailing slashes on the path are trimmed to avoid `//logout`.
    pub fn get_logout_url(&self) -> String {
        let (path, query) = match self.server_url.split_once('?') {
            Some((path, query)) => (path, Some(query)),
            None => (self.server_url.as_str(), None),
        };
        let mut endpoint = format!("{}/logout", path.trim_end_matches('/'));
        if let Some(query) = query {
            endpoint.push('?');
            endpoint.push_str(query);
        }
        Self::append_service_param(&endpoint, &self.service_url)
    }

    /// Returns `true` if storage holds either the `sso_client` token key or the
    /// default token key for `login_id`. Storage errors count as "not found".
    pub async fn check_local_login(&self, login_id: &str) -> bool {
        let key = format!("sa:login:token:{}:sso_client", login_id);
        match self.manager.storage.get(&key).await {
            Ok(Some(_)) => true,
            _ => {
                let key_default = format!("sa:login:token:{}", login_id);
                matches!(self.manager.storage.get(&key_default).await, Ok(Some(_)))
            }
        }
    }

    /// Checks that `service` is this client's URL and echoes `ticket` back.
    ///
    /// NOTE(review): this does not contact the SSO server to validate the ticket.
    /// Callers must validate it separately (e.g. via `SsoServer::validate_ticket`).
    ///
    /// # Errors
    /// `ServiceMismatch` if `service` differs from this client's `service_url`.
    pub async fn process_ticket(&self, ticket: &str, service: &str) -> SaTokenResult<String> {
        if service != self.service_url {
            return Err(SaTokenError::ServiceMismatch);
        }
        Ok(ticket.to_string())
    }

    /// Creates a local login for `login_id` under the `"sso_client"` device tag and
    /// returns the resulting token as a string.
    ///
    /// # Errors
    /// Propagates any error from `login_with_options`.
    pub async fn login_by_ticket(&self, login_id: String) -> SaTokenResult<String> {
        let token = self
            .manager
            .login_with_options(
                &login_id,
                Some("sso_client".to_string()),
                None,
                Some(serde_json::json!({
                    "sso_client": true,
                    "service_url": self.service_url.clone()
                })),
                None,
                None,
            )
            .await?;
        Ok(token.to_string())
    }

    /// Handles a logout for `login_id`. It runs the callback (if any), deletes the
    /// `sso_client` storage key on a best-effort basis, then logs the user out
    /// through the manager.
    ///
    /// NOTE(review): the callback's `bool` result is ignored; confirm whether
    /// `false` was meant to veto the local logout.
    ///
    /// # Errors
    /// Propagates errors from `logout_by_login_id`.
    pub async fn handle_logout(&self, login_id: &str) -> SaTokenResult<()> {
        if let Some(callback) = &self.logout_callback {
            callback(login_id);
        }
        let sso_client_key = format!("sa:login:token:{}:sso_client", login_id);
        // Best-effort: the manager-level logout below is the authoritative step.
        let _ = self.manager.storage.delete(&sso_client_key).await;
        self.manager.logout_by_login_id(login_id).await?;
        Ok(())
    }

    /// SSO server URL this client redirects to.
    pub fn server_url(&self) -> &str {
        &self.server_url
    }

    /// This service's URL.
    pub fn service_url(&self) -> &str {
        &self.service_url
    }
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SsoConfig {
pub server_url: String,
pub ticket_timeout: i64,
pub allow_cross_domain: bool,
pub allowed_origins: Vec<String>,
}
impl Default for SsoConfig {
fn default() -> Self {
Self {
server_url: "http://localhost:8080/sso".to_string(),
ticket_timeout: 300,
allow_cross_domain: true,
allowed_origins: vec!["*".to_string()],
}
}
}
impl SsoConfig {
pub fn builder() -> SsoConfigBuilder {
SsoConfigBuilder::default()
}
}
#[derive(Default)]
pub struct SsoConfigBuilder {
config: SsoConfig,
}
impl SsoConfigBuilder {
pub fn server_url(mut self, url: impl Into<String>) -> Self {
self.config.server_url = url.into();
self
}
pub fn ticket_timeout(mut self, timeout: i64) -> Self {
self.config.ticket_timeout = timeout;
self
}
pub fn allow_cross_domain(mut self, allow: bool) -> Self {
self.config.allow_cross_domain = allow;
self
}
pub fn allowed_origins(mut self, origins: Vec<String>) -> Self {
self.config.allowed_origins = origins;
self
}
pub fn add_allowed_origin(mut self, origin: String) -> Self {
if self.config.allowed_origins == vec!["*".to_string()] {
self.config.allowed_origins = vec![origin];
} else {
self.config.allowed_origins.push(origin);
}
self
}
pub fn build(self) -> SsoConfig {
self.config
}
}
/// Bundles an optional SSO server, an optional SSO client and the shared `SsoConfig`.
pub struct SsoManager {
    server: Option<Arc<SsoServer>>,
    client: Option<Arc<SsoClient>>,
    config: SsoConfig,
}

impl SsoManager {
    /// Creates a manager with neither a server nor a client attached.
    pub fn new(config: SsoConfig) -> Self {
        Self {
            server: None,
            client: None,
            config,
        }
    }

    /// Attaches the SSO server role.
    pub fn with_server(mut self, server: Arc<SsoServer>) -> Self {
        self.server = Some(server);
        self
    }

    /// Attaches the SSO client role.
    pub fn with_client(mut self, client: Arc<SsoClient>) -> Self {
        self.client = Some(client);
        self
    }

    /// The attached server, if any.
    pub fn server(&self) -> Option<&Arc<SsoServer>> {
        self.server.as_ref()
    }

    /// The attached client, if any.
    pub fn client(&self) -> Option<&Arc<SsoClient>> {
        self.client.as_ref()
    }

    /// The SSO configuration.
    pub fn config(&self) -> &SsoConfig {
        &self.config
    }

    /// Returns whether `origin` may make cross-domain requests.
    ///
    /// Always `false` when `allow_cross_domain` is off. Otherwise it is `true` if
    /// the allowed list contains `"*"` or an exact (case-sensitive) match for `origin`.
    pub fn is_allowed_origin(&self, origin: &str) -> bool {
        // Compare borrowed strings directly; this check can run per request, so
        // avoid allocating temporary `String`s just to call `contains`.
        self.config.allow_cross_domain
            && self
                .config
                .allowed_origins
                .iter()
                .any(|allowed| allowed == "*" || allowed == origin)
    }
}