Skip to main content

gcloud_spanner/
session.rs

1use std::collections::VecDeque;
2use std::mem;
3use std::ops::{Deref, DerefMut};
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use parking_lot::{Mutex, RwLock};
8use thiserror;
9use tokio::select;
10use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
11use tokio::sync::{mpsc, oneshot};
12use tokio::task::{JoinHandle, JoinSet};
13use tokio::time::{sleep, timeout};
14use tokio_util::sync::CancellationToken;
15
16use tracing::{Instrument, Span};
17
18use google_cloud_gax::grpc::metadata::MetadataMap;
19use google_cloud_gax::grpc::{Code, Status};
20use google_cloud_gax::retry::TryAs;
21use google_cloud_googleapis::spanner::v1::{BatchCreateSessionsRequest, DeleteSessionRequest, Session};
22
23use crate::apiv1::conn_pool::ConnectionManager;
24use crate::apiv1::spanner_client::{ping_query_request, Client};
25use crate::metrics::{MetricsRecorder, SessionPoolSnapshot, SessionPoolStatsFn};
26
/// Length of the rolling window over which the peak number of in-use sessions is tracked (10 minutes).
const MAX_IN_USE_WINDOW: Duration = Duration::from_secs(10 * 60);
28
29/// Session
30pub struct SessionHandle {
31    pub session: Session,
32    pub spanner_client: Client,
33    valid: bool,
34    deleted: bool,
35    last_used_at: Instant,
36    last_checked_at: Instant,
37    last_pong_at: Instant,
38    created_at: Instant,
39}
40
41impl SessionHandle {
42    pub(crate) fn new(session: Session, spanner_client: Client, now: Instant) -> SessionHandle {
43        SessionHandle {
44            session,
45            spanner_client,
46            valid: true,
47            deleted: false,
48            last_used_at: now,
49            last_checked_at: now,
50            last_pong_at: now,
51            created_at: now,
52        }
53    }
54
55    pub async fn invalidate_if_needed<T>(&mut self, arg: Result<T, Status>) -> Result<T, Status> {
56        match arg {
57            Ok(s) => Ok(s),
58            Err(e) => {
59                if e.code() == Code::NotFound && e.message().contains("Session not found:") {
60                    tracing::debug!("session invalidate {}", self.session.name);
61                    self.delete().await;
62                }
63                Err(e)
64            }
65        }
66    }
67
68    async fn delete(&mut self) {
69        self.valid = false;
70        let session_name = &self.session.name;
71        let request = DeleteSessionRequest {
72            name: session_name.to_string(),
73        };
74        match self.spanner_client.delete_session(request, true, None).await {
75            Ok(_) => self.deleted = true,
76            Err(e) => tracing::warn!("failed to delete session {}, {:?}", session_name, e),
77        };
78    }
79}
80
81/// ManagedSession
82pub struct ManagedSession {
83    session_pool: SessionPool,
84    session: Option<SessionHandle>,
85}
86
87impl ManagedSession {
88    fn new(session_pool: SessionPool, session: SessionHandle) -> Self {
89        ManagedSession {
90            session_pool,
91            session: Some(session),
92        }
93    }
94}
95
96impl Drop for ManagedSession {
97    fn drop(&mut self) {
98        let session = self.session.take().unwrap();
99        self.session_pool.recycle(session);
100    }
101}
102
103impl Deref for ManagedSession {
104    type Target = SessionHandle;
105
106    fn deref(&self) -> &Self::Target {
107        self.session.as_ref().unwrap()
108    }
109}
110
111impl DerefMut for ManagedSession {
112    fn deref_mut(&mut self) -> &mut Self::Target {
113        self.session.as_mut().unwrap()
114    }
115}
116
117/// Sessions have all sessions and waiters.
118/// This is for atomically locking the waiting list and free sessions.
119struct Sessions {
120    available_sessions: VecDeque<SessionHandle>,
121
122    waiters: VecDeque<oneshot::Sender<()>>,
123
124    /// Invalid sessions living in the server.
125    orphans: Vec<SessionHandle>,
126
127    /// number of sessions user uses.
128    num_inuse: usize,
129
130    /// number of sessions scheduled to be replenished.
131    num_creating: usize,
132
133    /// Maximum observed number of sessions in use during the current window.
134    max_inuse_window: usize,
135    /// Start of the rolling window used for `max_inuse_window`.
136    window_started_at: Instant,
137}
138
139impl Sessions {
140    fn num_opened(&self) -> usize {
141        self.num_inuse + self.available_sessions.len()
142    }
143
144    fn take_waiter(&mut self) -> Option<oneshot::Sender<()>> {
145        while let Some(waiter) = self.waiters.pop_front() {
146            // Waiter can be closed when session acquisition times out.
147            if !waiter.is_closed() {
148                return Some(waiter);
149            }
150        }
151        None
152    }
153
154    fn take(&mut self) -> Option<SessionHandle> {
155        match self.available_sessions.pop_front() {
156            None => None,
157            Some(s) => {
158                self.num_inuse += 1;
159                self.update_max_in_use();
160                Some(s)
161            }
162        }
163    }
164
165    fn release(&mut self, session: SessionHandle) {
166        if self.num_inuse > 0 {
167            self.num_inuse -= 1;
168        }
169        if session.valid {
170            self.available_sessions.push_back(session);
171        } else if !session.deleted {
172            tracing::trace!("save as orphan name={}", session.session.name);
173            self.orphans.push(session);
174        }
175    }
176
177    /// reserve calculates next session count to create.
178    /// Must call replenish after calling this method.
179    fn reserve(&mut self, max_opened: usize, inc_step: usize) -> usize {
180        let num_opened = self.num_opened();
181        let num_creating = self.num_creating;
182        if max_opened < num_creating + num_opened {
183            tracing::trace!(
184                "No available connections max={}, num_creating={}, current={}",
185                max_opened,
186                num_creating,
187                num_opened
188            );
189            return 0;
190        }
191        let mut increasing = max_opened - (num_creating + num_opened);
192        if increasing > inc_step {
193            increasing = inc_step
194        }
195        self.num_creating += increasing;
196        increasing
197    }
198
199    fn replenish(&mut self, session_count: usize, result: Result<Vec<SessionHandle>, Status>) {
200        self.num_creating -= session_count;
201        match result {
202            Ok(mut new_sessions) => {
203                while let Some(session) = new_sessions.pop() {
204                    self.available_sessions.push_back(session);
205                    if let Some(waiter) = self.take_waiter() {
206                        let _ = waiter.send(());
207                    }
208                }
209            }
210            Err(e) => tracing::error!("failed to create new sessions {:?}", e),
211        }
212    }
213
214    fn update_max_in_use(&mut self) {
215        let now = Instant::now();
216        if now.duration_since(self.window_started_at) >= MAX_IN_USE_WINDOW {
217            self.window_started_at = now;
218            self.max_inuse_window = self.num_inuse;
219        } else if self.num_inuse > self.max_inuse_window {
220            self.max_inuse_window = self.num_inuse;
221        }
222    }
223}
224
225#[derive(Clone)]
226struct SessionPool {
227    inner: Arc<RwLock<Sessions>>,
228    session_creation_sender: UnboundedSender<usize>,
229    config: Arc<SessionConfig>,
230    metrics: Arc<MetricsRecorder>,
231}
232
233impl SessionPool {
234    async fn new(
235        database: String,
236        conn_pool: &ConnectionManager,
237        session_creation_sender: UnboundedSender<usize>,
238        config: Arc<SessionConfig>,
239        disable_route_to_leader: bool,
240        metrics: Arc<MetricsRecorder>,
241    ) -> Result<Self, Status> {
242        let available_sessions =
243            Self::init_pool(database, conn_pool, config.min_opened, disable_route_to_leader, metrics.clone()).await?;
244        let pool = SessionPool {
245            inner: Arc::new(RwLock::new(Sessions {
246                available_sessions,
247                waiters: VecDeque::new(),
248                orphans: Vec::new(),
249                num_inuse: 0,
250                num_creating: 0,
251                max_inuse_window: 0,
252                window_started_at: Instant::now(),
253            })),
254            session_creation_sender,
255            config,
256            metrics,
257        };
258        pool.metrics.register_session_pool(pool.snapshot_fn());
259        Ok(pool)
260    }
261
262    async fn init_pool(
263        database: String,
264        conn_pool: &ConnectionManager,
265        min_opened: usize,
266        disable_route_to_leader: bool,
267        metrics: Arc<MetricsRecorder>,
268    ) -> Result<VecDeque<SessionHandle>, Status> {
269        let channel_num = conn_pool.num();
270        let creation_count_per_channel = min_opened / channel_num;
271        let remainder = min_opened % channel_num;
272
273        let mut sessions = Vec::<SessionHandle>::new();
274        let mut tasks = JoinSet::new();
275        for i in 0..channel_num {
276            // Ensure that we create the exact number of requested sessions by adding the remainder to the first channel.
277            let creation_count = if i == 0 {
278                creation_count_per_channel + remainder
279            } else {
280                creation_count_per_channel
281            };
282            let next_client = conn_pool
283                .conn()
284                .with_metrics(metrics.clone())
285                .with_metadata(client_metadata(&database));
286            let database = database.clone();
287            tasks.spawn(async move {
288                batch_create_sessions(next_client, &database, creation_count, disable_route_to_leader).await
289            }.instrument(Span::current()));
290        }
291        while let Some(r) = tasks.join_next().await {
292            let new_sessions = r.map_err(|e| Status::from_error(e.into()))??;
293            sessions.extend(new_sessions);
294        }
295        tracing::debug!("initial session created count = {}", sessions.len());
296        Ok(sessions.into())
297    }
298
299    fn num_opened(&self) -> usize {
300        self.inner.read().num_opened()
301    }
302
303    /// The client first checks the waiting list.
304    /// If the waiting list is empty, it retrieves the first available session.
305    /// If there are no available sessions, it enters the waiting list.
306    /// If the waiting list is not empty, the client enters the waiting list.
307    /// The client on the waiting list will be notified when another client's session has finished and
308    /// when the process of replenishing the available sessions is complete.
309    async fn acquire(&self) -> Result<ManagedSession, SessionError> {
310        let request_started_at = Instant::now();
311        loop {
312            let (on_session_acquired, session_count) = {
313                let mut sessions = self.inner.write();
314
315                // Prioritize waiters over new acquirers.
316                if sessions.waiters.is_empty() {
317                    if let Some(mut s) = sessions.take() {
318                        s.last_used_at = Instant::now();
319                        self.metrics.record_session_acquired();
320                        self.metrics
321                            .record_session_acquire_latency(request_started_at.elapsed());
322                        return Ok(ManagedSession::new(self.clone(), s));
323                    }
324                }
325                // Add the participant to the waiting list.
326                let (sender, receiver) = oneshot::channel();
327                sessions.waiters.push_back(sender);
328                let session_count = sessions.reserve(self.config.max_opened, self.config.inc_step);
329                (receiver, session_count)
330            };
331
332            if session_count > 0 {
333                let _ = self.session_creation_sender.send(session_count);
334            }
335
336            // Wait for the session available notification.
337            match timeout(self.config.session_get_timeout, on_session_acquired).await {
338                Ok(Ok(())) => {
339                    let mut sessions = self.inner.write();
340                    if let Some(mut s) = sessions.take() {
341                        s.last_used_at = Instant::now();
342                        self.metrics.record_session_acquired();
343                        self.metrics
344                            .record_session_acquire_latency(request_started_at.elapsed());
345                        return Ok(ManagedSession::new(self.clone(), s));
346                    } else {
347                        continue; // another waiter raced for session
348                    }
349                }
350                _ => {
351                    {
352                        let sessions = self.inner.write();
353                        tracing::info!(
354                            available = sessions.available_sessions.len(),
355                            waiters = sessions.waiters.len(),
356                            orphans = sessions.orphans.len(),
357                            num_inuse = sessions.num_inuse,
358                            num_creating = sessions.num_creating,
359                            max_opened = self.config.max_opened,
360                            "Timeout acquiring session"
361                        );
362                    }
363                    self.metrics.record_session_timeout();
364                    return Err(SessionError::SessionGetTimeout);
365                }
366            }
367        }
368    }
369
370    /// If the session is valid
371    ///  - Pass the session to the first user on the waiting list.
372    ///  - If there is no waiting list, the session is returned to the list of available sessions.
373    ///    If the session is invalid
374    ///  - Discard the session. If the number of sessions falls below the threshold as a result of discarding, the session replenishment process is called.
375    fn recycle(&self, mut session: SessionHandle) {
376        self.metrics.record_session_released();
377        if session.valid {
378            let mut sessions = self.inner.write();
379            let waiter = sessions.take_waiter();
380            if sessions.num_opened() > self.config.max_idle
381                && session.created_at + self.config.idle_timeout < Instant::now()
382                && waiter.is_none()
383            {
384                // Not reuse expired idle session
385                session.valid = false
386            }
387            sessions.release(session);
388            if let Some(waiter) = waiter {
389                let _ = waiter.send(());
390            }
391        } else {
392            let session_count = {
393                let mut sessions = self.inner.write();
394                sessions.release(session);
395                if sessions.num_opened() < self.config.min_opened && !sessions.waiters.is_empty() {
396                    sessions.reserve(self.config.max_opened, self.config.inc_step)
397                } else {
398                    0
399                }
400            };
401            if session_count > 0 {
402                let _ = self.session_creation_sender.send(session_count);
403            }
404        }
405    }
406
407    async fn close(&self) {
408        let empty = VecDeque::new();
409        let deleting_sessions = { mem::replace(&mut self.inner.write().available_sessions, empty) };
410        for mut session in deleting_sessions {
411            session.delete().await;
412        }
413
414        self.remove_orphans().await;
415    }
416
417    fn snapshot_fn(&self) -> SessionPoolStatsFn {
418        let inner = self.inner.clone();
419        let max_allowed = self.config.max_opened;
420        Arc::new(move || {
421            let sessions = inner.read();
422            SessionPoolSnapshot {
423                open_sessions: sessions.num_opened(),
424                sessions_in_use: sessions.num_inuse,
425                idle_sessions: sessions.available_sessions.len(),
426                max_allowed_sessions: max_allowed,
427                max_in_use_last_window: sessions.max_inuse_window,
428                has_multiplexed_session: false,
429            }
430        })
431    }
432
433    async fn remove_orphans(&self) {
434        let empty = vec![];
435        let deleting_sessions = { mem::replace(&mut self.inner.write().orphans, empty) };
436        tracing::trace!("remove {} orphan sessions", deleting_sessions.len());
437        for mut session in deleting_sessions {
438            session.delete().await;
439        }
440    }
441}
442
/// Limits and timings of the session pool.
#[derive(Clone, Debug)]
pub struct SessionConfig {
    /// Upper bound on the number of sessions the pool opens. Once it is
    /// reached, a caller asking for a session waits until one is returned or
    /// `session_get_timeout` elapses.
    pub max_opened: usize,

    /// Number of sessions created when the pool starts. When an invalid
    /// session is returned and fewer than `min_opened` sessions remain while
    /// callers are waiting, new sessions are requested. A broken session is
    /// still evicted, so the number of opened sessions may drop below
    /// `min_opened`.
    pub min_opened: usize,

    /// Number of open sessions above which returned sessions become
    /// candidates for retirement (see `idle_timeout`).
    pub max_idle: usize,

    /// Age limit applied when a session is returned to the pool: if more than
    /// `max_idle` sessions are open and nobody is waiting, a session created
    /// longer ago than this is discarded instead of being reused.
    pub idle_timeout: Duration,

    /// How long after its last use or last successful ping a session is
    /// treated as alive by the health check without being pinged again.
    pub session_alive_trust_duration: Duration,

    /// Maximum time a caller waits for a session when none is idle.
    pub session_get_timeout: Duration,

    /// Interval between runs of orphan cleanup and health checking.
    pub refresh_interval: Duration,

    /// Largest number of sessions requested in one batch when at least one
    /// more session is needed.
    inc_step: usize,
}
479
480impl Default for SessionConfig {
481    fn default() -> Self {
482        SessionConfig {
483            max_opened: 400,
484            min_opened: 10,
485            max_idle: 300,
486            inc_step: 25,
487            idle_timeout: Duration::from_secs(30 * 60),
488            session_alive_trust_duration: Duration::from_secs(55 * 60),
489            session_get_timeout: Duration::from_secs(1),
490            refresh_interval: Duration::from_secs(5 * 60),
491        }
492    }
493}
494
495#[derive(thiserror::Error, Debug)]
496pub enum SessionError {
497    #[error("session get time out")]
498    SessionGetTimeout,
499    #[error("failed to create session")]
500    FailedToCreateSession,
501    #[error(transparent)]
502    GRPC(#[from] Status),
503}
504
505impl TryAs<Status> for SessionError {
506    fn try_as(&self) -> Option<&Status> {
507        match self {
508            SessionError::GRPC(e) => Some(e),
509            _ => None,
510        }
511    }
512}
513
514pub(crate) struct SessionManager {
515    session_pool: SessionPool,
516    cancel: CancellationToken,
517    tasks: Mutex<Vec<JoinHandle<()>>>,
518}
519
520impl SessionManager {
521    pub async fn new(
522        database: impl Into<String>,
523        conn_pool: ConnectionManager,
524        config: SessionConfig,
525        disable_route_to_leader: bool,
526        metrics: Arc<MetricsRecorder>,
527    ) -> Result<Arc<SessionManager>, Status> {
528        let database = database.into();
529        let (sender, receiver) = mpsc::unbounded_channel();
530        let session_pool = SessionPool::new(
531            database.clone(),
532            &conn_pool,
533            sender,
534            Arc::new(config.clone()),
535            disable_route_to_leader,
536            metrics.clone(),
537        )
538        .await?;
539
540        let cancel = CancellationToken::new();
541        let task_session_cleaner = Self::spawn_health_check_task(config, session_pool.clone(), cancel.clone());
542        let task_session_creator = Self::spawn_session_creation_task(
543            session_pool.clone(),
544            database,
545            conn_pool,
546            receiver,
547            cancel.clone(),
548            disable_route_to_leader,
549        );
550
551        let sm = SessionManager {
552            session_pool,
553            cancel,
554            tasks: Mutex::new(vec![task_session_cleaner, task_session_creator]),
555        };
556        Ok(Arc::new(sm))
557    }
558
559    pub fn num_opened(&self) -> usize {
560        self.session_pool.num_opened()
561    }
562
563    pub async fn get(&self) -> Result<ManagedSession, SessionError> {
564        self.session_pool.acquire().await
565    }
566
567    pub async fn close(&self) {
568        if self.cancel.is_cancelled() {
569            return;
570        }
571        self.cancel.cancel();
572        let tasks = { mem::take(&mut *self.tasks.lock()) };
573        for task in tasks {
574            let _ = task.await;
575        }
576        self.session_pool.close().await;
577    }
578
579    fn spawn_session_creation_task(
580        session_pool: SessionPool,
581        database: String,
582        conn_pool: ConnectionManager,
583        mut rx: UnboundedReceiver<usize>,
584        cancel: CancellationToken,
585        disable_route_to_leader: bool,
586    ) -> JoinHandle<()> {
587        tokio::spawn(async move {
588            let mut tasks = JoinSet::default();
589            loop {
590                select! {
591                    biased;
592                    _ = cancel.cancelled() => break,
593                    Some(Ok((session_count, result))) = tasks.join_next(), if !tasks.is_empty() => {
594                        session_pool.inner.write().replenish(session_count, result);
595                    }
596                    session_count = rx.recv() => match session_count {
597                        Some(session_count) => {
598                            let client = conn_pool
599                                .conn()
600                                .with_metrics(session_pool.metrics.clone())
601                                .with_metadata(client_metadata(&database));
602                            let database = database.clone();
603                            tasks.spawn(async move { (session_count, batch_create_sessions(client, &database, session_count, disable_route_to_leader).await) });
604                        },
605                        None => continue
606                    },
607                }
608            }
609            tracing::trace!("shutdown session creation task.");
610        })
611    }
612
613    fn spawn_health_check_task(
614        config: SessionConfig,
615        session_pool: SessionPool,
616        cancel: CancellationToken,
617    ) -> JoinHandle<()> {
618        let start = Instant::now() + config.refresh_interval;
619        let mut interval = tokio::time::interval_at(start.into(), config.refresh_interval);
620
621        tokio::spawn(async move {
622            loop {
623                select! {
624                    _ = interval.tick() => {},
625                    _ = cancel.cancelled() => break
626                }
627                let now = Instant::now();
628
629                // remove orphans first
630                session_pool.remove_orphans().await;
631
632                // start health check
633                health_check(
634                    now + Duration::from_nanos(1),
635                    config.session_alive_trust_duration,
636                    &session_pool,
637                    cancel.clone(),
638                )
639                .await;
640            }
641            tracing::trace!("shutdown health check task.")
642        })
643    }
644}
645
646async fn health_check(
647    now: Instant,
648    session_alive_trust_duration: Duration,
649    sessions: &SessionPool,
650    cancel: CancellationToken,
651) {
652    tracing::trace!("start health check");
653    let start = Instant::now();
654    let sleep_duration = Duration::from_millis(10);
655    loop {
656        select! {
657            _ = sleep(sleep_duration) => {},
658            _ = cancel.cancelled() => break
659        }
660        let mut s = {
661            // temporary take
662            let mut locked = sessions.inner.write();
663            match locked.take() {
664                Some(mut s) => {
665                    // all the session check complete.
666                    if s.last_checked_at == now {
667                        locked.release(s);
668                        break;
669                    }
670                    if std::cmp::max(s.last_used_at, s.last_pong_at) + session_alive_trust_duration >= now {
671                        s.last_checked_at = now;
672                        locked.release(s);
673                        continue;
674                    }
675                    s
676                }
677                None => break,
678            }
679        };
680
681        let request = ping_query_request(s.session.name.clone());
682        match s.spanner_client.execute_sql(request, true, None).await {
683            Ok(_) => {
684                s.last_checked_at = now;
685                s.last_pong_at = now;
686                sessions.recycle(s);
687            }
688            Err(_) => {
689                s.delete().await;
690                sessions.recycle(s);
691            }
692        }
693    }
694    tracing::trace!("end health check elapsed={}msec", start.elapsed().as_millis());
695}
696
697async fn batch_create_sessions(
698    spanner_client: Client,
699    database: &str,
700    mut remaining_create_count: usize,
701    disable_route_to_leader: bool,
702) -> Result<Vec<SessionHandle>, Status> {
703    let mut created = Vec::with_capacity(remaining_create_count);
704    while remaining_create_count > 0 {
705        let sessions = batch_create_session(
706            spanner_client.clone(),
707            database,
708            remaining_create_count,
709            disable_route_to_leader,
710        )
711        .await?;
712        // Spanner could return less sessions than requested.
713        // In that case, we should do another call using the same gRPC channel.
714        let actually_created = sessions.len();
715        remaining_create_count -= actually_created;
716        created.extend(sessions);
717    }
718    Ok(created)
719}
720
721async fn batch_create_session(
722    mut spanner_client: Client,
723    database: &str,
724    session_count: usize,
725    disable_route_to_leader: bool,
726) -> Result<Vec<SessionHandle>, Status> {
727    let request = BatchCreateSessionsRequest {
728        database: database.to_string(),
729        session_template: None,
730        session_count: session_count as i32,
731    };
732
733    tracing::debug!("spawn session creation request : session_count = {}", session_count);
734    let response = spanner_client
735        .batch_create_sessions(request, disable_route_to_leader, None)
736        .await?
737        .into_inner();
738
739    let now = Instant::now();
740    Ok(response
741        .session
742        .into_iter()
743        .map(|s| SessionHandle::new(s, spanner_client.clone(), now))
744        .collect::<Vec<SessionHandle>>())
745}
746
747pub(crate) fn client_metadata(database: &str) -> MetadataMap {
748    let mut metadata = MetadataMap::new();
749    metadata.insert("google-cloud-resource-prefix", database.parse().unwrap());
750    metadata
751}
752
753#[cfg(test)]
754mod tests {
755    use std::sync::atomic::{AtomicI64, Ordering};
756    use std::sync::Arc;
757    use std::time::{Duration, Instant};
758
759    use parking_lot::RwLock;
760    use serial_test::serial;
761    use tokio::time::sleep;
762    use tokio_util::sync::CancellationToken;
763
764    use google_cloud_gax::conn::{ConnectionOptions, Environment};
765    use google_cloud_googleapis::spanner::v1::ExecuteSqlRequest;
766
767    use crate::apiv1::conn_pool::ConnectionManager;
768    use crate::metrics::MetricsRecorder;
769    use crate::session::{
770        batch_create_sessions, client_metadata, health_check, SessionConfig, SessionError, SessionManager,
771    };
772
773    pub const DATABASE: &str = "projects/local-project/instances/test-instance/databases/local-database";
774
775    #[ctor::ctor]
776    fn init() {
777        let filter = tracing_subscriber::filter::EnvFilter::from_default_env()
778            .add_directive("google_cloud_spanner=trace".parse().unwrap());
779        let _ = tracing_subscriber::fmt().with_env_filter(filter).try_init();
780    }
781
782    async fn assert_rush(use_invalidate: bool, config: SessionConfig) -> Arc<SessionManager> {
783        let cm = ConnectionManager::new(
784            4,
785            &Environment::Emulator("localhost:9010".to_string()),
786            "",
787            &ConnectionOptions::default(),
788        )
789        .await
790        .unwrap();
791        let sm = SessionManager::new(DATABASE, cm, config, false, Arc::new(MetricsRecorder::default()))
792            .await
793            .unwrap();
794
795        let counter = Arc::new(AtomicI64::new(0));
796        let mut spawns = Vec::with_capacity(100);
797        for _ in 0..100 {
798            let sm = sm.clone();
799            let counter = Arc::clone(&counter);
800            spawns.push(tokio::spawn(async move {
801                let mut session = sm.get().await.unwrap();
802                if use_invalidate {
803                    session.delete().await;
804                }
805                counter.fetch_add(1, Ordering::SeqCst);
806                sleep(Duration::from_millis(300)).await;
807            }));
808        }
809        for handler in spawns {
810            let _ = handler.await;
811        }
812        sm
813    }
814
815    #[tokio::test(flavor = "multi_thread")]
816    #[serial]
817    async fn test_health_check_checked() {
818        let cm = ConnectionManager::new(
819            4,
820            &Environment::Emulator("localhost:9010".to_string()),
821            "",
822            &ConnectionOptions::default(),
823        )
824        .await
825        .unwrap();
826        let session_alive_trust_duration = Duration::from_millis(10);
827        let config = SessionConfig {
828            min_opened: 5,
829            session_alive_trust_duration,
830            max_opened: 5,
831            ..Default::default()
832        };
833        let sm = std::sync::Arc::new(
834            SessionManager::new(DATABASE, cm, config, false, Arc::new(MetricsRecorder::default()))
835                .await
836                .unwrap(),
837        );
838        sleep(Duration::from_secs(1)).await;
839
840        let cancel = CancellationToken::new();
841        health_check(Instant::now(), session_alive_trust_duration, &sm.session_pool, cancel.clone()).await;
842
843        assert_eq!(sm.num_opened(), 5);
844        tokio::time::sleep(Duration::from_millis(500)).await;
845        cancel.cancel();
846    }
847
848    #[tokio::test(flavor = "multi_thread")]
849    #[serial]
850    async fn test_health_check_not_checked() {
851        let cm = ConnectionManager::new(
852            4,
853            &Environment::Emulator("localhost:9010".to_string()),
854            "",
855            &ConnectionOptions::default(),
856        )
857        .await
858        .unwrap();
859        let session_alive_trust_duration = Duration::from_secs(10);
860        let config = SessionConfig {
861            min_opened: 5,
862            session_alive_trust_duration,
863            max_opened: 5,
864            ..Default::default()
865        };
866        let sm = Arc::new(
867            SessionManager::new(DATABASE, cm, config, false, Arc::new(MetricsRecorder::default()))
868                .await
869                .unwrap(),
870        );
871        sleep(Duration::from_secs(1)).await;
872
873        let cancel = CancellationToken::new();
874        health_check(Instant::now(), session_alive_trust_duration, &sm.session_pool, cancel.clone()).await;
875
876        assert_eq!(sm.num_opened(), 5);
877        sleep(Duration::from_millis(500)).await;
878        cancel.cancel();
879    }
880
881    #[tokio::test(flavor = "multi_thread")]
882    #[serial]
883    async fn test_increase_session_and_idle_session_expired() {
884        let conn_pool = ConnectionManager::new(
885            4,
886            &Environment::Emulator("localhost:9010".to_string()),
887            "",
888            &ConnectionOptions::default(),
889        )
890        .await
891        .unwrap();
892        let config = SessionConfig {
893            idle_timeout: Duration::from_millis(10),
894            min_opened: 10,
895            max_idle: 20,
896            max_opened: 45,
897            ..Default::default()
898        };
899        let sm = SessionManager::new(DATABASE, conn_pool, config, false, Arc::new(MetricsRecorder::default()))
900            .await
901            .unwrap();
902        {
903            let mut sessions = Vec::new();
904            for _ in 0..45 {
905                sessions.push(sm.get().await.unwrap());
906            }
907
908            // all the session are using
909            assert_eq!(sm.num_opened(), 45);
910            assert_eq!(sm.session_pool.inner.read().num_inuse, 45, "all the session are using");
911            sleep(Duration::from_secs(1)).await;
912        }
913
914        // idle session removed after drop
915        let sessions = sm.session_pool.inner.read();
916        assert_eq!(sessions.num_inuse, 0, "invalid num_inuse");
917        assert_eq!(sessions.available_sessions.len(), 20, "invalid available sessions");
918        assert_eq!(sessions.num_opened(), 20, "invalid num open");
919        assert_eq!(sessions.waiters.len(), 0, "session waiters is 0");
920    }
921
922    #[tokio::test(flavor = "multi_thread")]
923    #[serial]
924    async fn test_too_many_session_timeout() {
925        let conn_pool = ConnectionManager::new(
926            4,
927            &Environment::Emulator("localhost:9010".to_string()),
928            "",
929            &ConnectionOptions::default(),
930        )
931        .await
932        .unwrap();
933        let config = SessionConfig {
934            idle_timeout: Duration::from_millis(10),
935            min_opened: 10,
936            max_idle: 20,
937            max_opened: 45,
938            session_get_timeout: Duration::from_secs(1),
939            ..Default::default()
940        };
941        let sm = Arc::new(
942            SessionManager::new(DATABASE, conn_pool, config.clone(), false, Arc::new(MetricsRecorder::default()))
943                .await
944                .unwrap(),
945        );
946        let mu = Arc::new(RwLock::new(Vec::new()));
947        let mut awaiters = Vec::with_capacity(100);
948        for _ in 0..100 {
949            let sm = sm.clone();
950            let mu = mu.clone();
951            awaiters.push(tokio::spawn(async move {
952                let session = sm.get().await;
953                mu.write().push(session);
954                0
955            }))
956        }
957        for handler in awaiters {
958            let _ = handler.await;
959        }
960        let sessions = mu.read();
961        for i in 0..sessions.len() - 1 {
962            let session = &sessions[i];
963            if i >= config.max_opened {
964                assert!(session.is_err(), "must err {i}");
965                match session.as_ref().err().unwrap() {
966                    SessionError::SessionGetTimeout => {}
967                    _ => {
968                        panic!("must be session timeout error")
969                    }
970                }
971            } else {
972                assert!(session.is_ok(), "must ok {i}");
973            }
974        }
975        let pool = sm.session_pool.inner.read();
976        assert_eq!(pool.num_opened(), config.max_opened);
977        assert_eq!(pool.waiters.len(), 100 - config.max_opened); //include timeout sessions
978    }
979
980    #[tokio::test(flavor = "multi_thread")]
981    #[serial]
982    async fn test_rush_invalidate() {
983        let config = SessionConfig {
984            session_get_timeout: Duration::from_secs(20),
985            min_opened: 10,
986            max_idle: 20,
987            max_opened: 45,
988            ..Default::default()
989        };
990        let sm = assert_rush(true, config.clone()).await;
991        {
992            let sessions = sm.session_pool.inner.read();
993            let available_sessions = sessions.available_sessions.len();
994            assert_eq!(sessions.num_inuse, 0);
995            assert_eq!(sessions.waiters.len(), 0);
996            assert_eq!(sessions.orphans.len(), 0);
997            assert!(
998                available_sessions <= config.max_opened && available_sessions >= config.min_opened,
999                "now is {available_sessions}"
1000            );
1001        }
1002        sm.close().await;
1003    }
1004
1005    #[tokio::test(flavor = "multi_thread")]
1006    #[serial]
1007    async fn test_rush() {
1008        let config = SessionConfig {
1009            min_opened: 10,
1010            max_idle: 20,
1011            max_opened: 45,
1012            ..Default::default()
1013        };
1014        let sm = assert_rush(false, config.clone()).await;
1015        {
1016            let sessions = sm.session_pool.inner.read();
1017            let available_sessions = sessions.available_sessions.len();
1018            assert_eq!(sessions.num_inuse, 0);
1019            assert_eq!(sessions.waiters.len(), 0);
1020            assert_eq!(sessions.orphans.len(), 0);
1021            assert!(
1022                available_sessions <= config.max_opened && available_sessions >= config.min_opened,
1023                "now is {available_sessions}"
1024            );
1025        }
1026        sm.close().await;
1027    }
1028
1029    #[tokio::test(flavor = "multi_thread")]
1030    #[serial]
1031    async fn test_rush_with_invalidate() {
1032        let config = SessionConfig {
1033            min_opened: 10,
1034            max_idle: 20,
1035            max_opened: 45,
1036            ..Default::default()
1037        };
1038        let sm = assert_rush(true, config.clone()).await;
1039        {
1040            let sessions = sm.session_pool.inner.read();
1041            let available_sessions = sessions.available_sessions.len();
1042            assert_eq!(sessions.num_inuse, 0);
1043            assert_eq!(sessions.waiters.len(), 0);
1044            assert_eq!(sessions.orphans.len(), 0);
1045            assert!(
1046                available_sessions <= config.max_opened && available_sessions >= config.min_opened,
1047                "now is {available_sessions}"
1048            );
1049        }
1050        sm.close().await;
1051    }
1052
1053    #[tokio::test(flavor = "multi_thread")]
1054    #[serial]
1055    async fn test_rush_with_health_check() {
1056        let config = SessionConfig {
1057            session_alive_trust_duration: Duration::from_millis(10),
1058            refresh_interval: Duration::from_millis(250),
1059            session_get_timeout: Duration::from_secs(20),
1060            min_opened: 10,
1061            max_idle: 20,
1062            max_opened: 45,
1063            ..Default::default()
1064        };
1065        let sm = assert_rush(false, config.clone()).await;
1066        sleep(Duration::from_secs(2)).await;
1067        {
1068            let sessions = sm.session_pool.inner.read();
1069            let available_sessions = sessions.available_sessions.len();
1070            assert!(sessions.num_inuse <= 1, "num_inuse is {}", sessions.num_inuse);
1071            assert_eq!(sessions.waiters.len(), 0);
1072            assert_eq!(sessions.orphans.len(), 0);
1073            assert!(
1074                available_sessions <= config.max_opened && available_sessions >= config.max_idle - 1,
1075                "now is {available_sessions}"
1076            );
1077        }
1078        sm.close().await;
1079    }
1080
1081    #[tokio::test(flavor = "multi_thread")]
1082    #[serial]
1083    async fn test_rush_with_health_check_and_invalidate() {
1084        let config = SessionConfig {
1085            session_alive_trust_duration: Duration::from_millis(10),
1086            refresh_interval: Duration::from_millis(250),
1087            session_get_timeout: Duration::from_secs(20),
1088            min_opened: 10,
1089            max_idle: 20,
1090            max_opened: 45,
1091            ..Default::default()
1092        };
1093        let sm = assert_rush(true, config.clone()).await;
1094        sleep(Duration::from_secs(2)).await;
1095        {
1096            let sessions = sm.session_pool.inner.read();
1097            let available_sessions = sessions.available_sessions.len();
1098            assert!(sessions.num_inuse <= 1, "num_inuse is {}", sessions.num_inuse);
1099            assert_eq!(sessions.waiters.len(), 0);
1100            assert_eq!(sessions.orphans.len(), 0);
1101            assert!(
1102                available_sessions <= config.max_opened && available_sessions >= config.min_opened - 1,
1103                "now is {available_sessions}"
1104            );
1105        }
1106        sm.close().await;
1107    }
1108
1109    #[tokio::test(flavor = "multi_thread")]
1110    #[serial]
1111    async fn test_rush_with_idle_expired() {
1112        let config = SessionConfig {
1113            min_opened: 10,
1114            max_idle: 20,
1115            max_opened: 45,
1116            idle_timeout: Duration::from_millis(1),
1117            ..Default::default()
1118        };
1119        let sm = assert_rush(false, config.clone()).await;
1120        {
1121            let sessions = sm.session_pool.inner.read();
1122            assert_eq!(sessions.num_inuse, 0);
1123            assert_eq!(sessions.waiters.len(), 0);
1124            assert_eq!(sessions.orphans.len(), config.max_opened - config.max_idle);
1125            assert_eq!(sessions.available_sessions.len(), config.max_idle);
1126        }
1127        sm.close().await;
1128    }
1129
1130    #[tokio::test(flavor = "multi_thread")]
1131    #[serial]
1132    async fn test_rush_with_health_check_and_idle_expired() {
1133        let config = SessionConfig {
1134            session_alive_trust_duration: Duration::from_millis(10),
1135            refresh_interval: Duration::from_millis(250),
1136            session_get_timeout: Duration::from_secs(20),
1137            min_opened: 10,
1138            max_idle: 20,
1139            max_opened: 45,
1140            idle_timeout: Duration::from_millis(1),
1141            ..Default::default()
1142        };
1143        let sm = assert_rush(false, config.clone()).await;
1144        sleep(Duration::from_secs(1)).await;
1145        {
1146            let sessions = sm.session_pool.inner.read();
1147            assert!(sessions.num_inuse <= 1, "num_inuse is {}", sessions.num_inuse);
1148            assert_eq!(sessions.waiters.len(), 0);
1149            assert_eq!(sessions.orphans.len(), 0);
1150            let available_sessions = sessions.available_sessions.len();
1151            assert!(
1152                available_sessions >= config.min_opened - 1 && available_sessions <= config.max_idle,
1153                "now is {available_sessions}"
1154            );
1155        }
1156        sm.close().await;
1157    }
1158
1159    #[tokio::test(flavor = "multi_thread")]
1160    #[serial]
1161    async fn test_rush_with_health_check_and_idle_expired_and_invalid() {
1162        let config = SessionConfig {
1163            session_alive_trust_duration: Duration::from_millis(10),
1164            refresh_interval: Duration::from_millis(250),
1165            session_get_timeout: Duration::from_secs(20),
1166            min_opened: 10,
1167            max_idle: 20,
1168            max_opened: 45,
1169            idle_timeout: Duration::from_millis(1),
1170            ..Default::default()
1171        };
1172        let sm = assert_rush(true, config.clone()).await;
1173        sleep(Duration::from_secs(2)).await;
1174        {
1175            let sessions = sm.session_pool.inner.read();
1176            assert!(sessions.num_inuse <= 1, "num_inuse is {}", sessions.num_inuse);
1177            // health checker removes orphans
1178            assert_eq!(sessions.orphans.len(), 0);
1179            assert_eq!(sessions.waiters.len(), 0, "invalid waiters");
1180            let available_sessions = sessions.available_sessions.len();
1181            assert!(
1182                available_sessions >= config.min_opened - 1 && available_sessions <= config.max_idle,
1183                "now is {available_sessions}"
1184            );
1185        }
1186        sm.close().await;
1187    }
1188
1189    #[tokio::test(flavor = "multi_thread")]
1190    #[serial]
1191    async fn test_close() {
1192        let cm = ConnectionManager::new(
1193            4,
1194            &Environment::Emulator("localhost:9010".to_string()),
1195            "",
1196            &ConnectionOptions::default(),
1197        )
1198        .await
1199        .unwrap();
1200        let config = SessionConfig::default();
1201        let sm = SessionManager::new(DATABASE, cm, config.clone(), false, Arc::new(MetricsRecorder::default()))
1202            .await
1203            .unwrap();
1204        assert_eq!(sm.num_opened(), config.min_opened);
1205        sm.close().await;
1206        assert_eq!(sm.num_opened(), 0);
1207        assert_eq!(sm.session_pool.inner.read().orphans.len(), 0);
1208    }
1209
1210    #[tokio::test(flavor = "multi_thread")]
1211    #[serial]
1212    async fn test_batch_create_sessions() {
1213        let cm = ConnectionManager::new(
1214            1,
1215            &Environment::Emulator("localhost:9010".to_string()),
1216            "",
1217            &ConnectionOptions::default(),
1218        )
1219        .await
1220        .unwrap();
1221        let client = cm
1222            .conn()
1223            .with_metrics(Arc::new(MetricsRecorder::default()))
1224            .with_metadata(client_metadata(DATABASE));
1225        let session_count = 125;
1226        let result = batch_create_sessions(client.clone(), DATABASE, session_count, false).await;
1227        match result {
1228            Ok(created) => {
1229                assert_eq!(session_count, created.len());
1230                for mut s in created {
1231                    let ping_result = s
1232                        .spanner_client
1233                        .execute_sql(
1234                            ExecuteSqlRequest {
1235                                session: s.session.name.to_string(),
1236                                transaction: None,
1237                                sql: "SELECT 1".to_string(),
1238                                params: None,
1239                                param_types: Default::default(),
1240                                resume_token: vec![],
1241                                query_mode: 0,
1242                                partition_token: vec![],
1243                                seqno: 0,
1244                                query_options: None,
1245                                request_options: None,
1246                                directed_read_options: None,
1247                                data_boost_enabled: false,
1248                                last_statement: false,
1249                            },
1250                            false,
1251                            None,
1252                        )
1253                        .await;
1254                    assert!(ping_result.is_ok());
1255                }
1256            }
1257            Err(err) => panic!("{err:?}"),
1258        }
1259    }
1260}