Skip to main content

gcloud_spanner/
session.rs

1use std::collections::VecDeque;
2use std::mem;
3use std::ops::{Deref, DerefMut};
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
7use parking_lot::{Mutex, RwLock};
8use thiserror;
9use tokio::select;
10use tokio::sync::mpsc::{UnboundedReceiver, UnboundedSender};
11use tokio::sync::{mpsc, oneshot};
12use tokio::task::{JoinHandle, JoinSet};
13use tokio::time::{sleep, timeout};
14use tokio_util::sync::CancellationToken;
15
16use google_cloud_gax::grpc::metadata::MetadataMap;
17use google_cloud_gax::grpc::{Code, Status};
18use google_cloud_gax::retry::TryAs;
19use google_cloud_googleapis::spanner::v1::{BatchCreateSessionsRequest, DeleteSessionRequest, Session};
20
21use crate::apiv1::conn_pool::ConnectionManager;
22use crate::apiv1::spanner_client::{ping_query_request, Client};
23use crate::metrics::{MetricsRecorder, SessionPoolSnapshot, SessionPoolStatsFn};
24
/// Length of the rolling window over which the peak number of in-use sessions is tracked
/// (reported as `max_in_use_last_window` in pool snapshots).
const MAX_IN_USE_WINDOW: Duration = Duration::from_secs(10 * 60);
26
27/// Session
28pub struct SessionHandle {
29    pub session: Session,
30    pub spanner_client: Client,
31    valid: bool,
32    deleted: bool,
33    last_used_at: Instant,
34    last_checked_at: Instant,
35    last_pong_at: Instant,
36    created_at: Instant,
37}
38
39impl SessionHandle {
40    pub(crate) fn new(session: Session, spanner_client: Client, now: Instant) -> SessionHandle {
41        SessionHandle {
42            session,
43            spanner_client,
44            valid: true,
45            deleted: false,
46            last_used_at: now,
47            last_checked_at: now,
48            last_pong_at: now,
49            created_at: now,
50        }
51    }
52
53    pub async fn invalidate_if_needed<T>(&mut self, arg: Result<T, Status>) -> Result<T, Status> {
54        match arg {
55            Ok(s) => Ok(s),
56            Err(e) => {
57                if e.code() == Code::NotFound && e.message().contains("Session not found:") {
58                    tracing::debug!("session invalidate {}", self.session.name);
59                    self.delete().await;
60                }
61                Err(e)
62            }
63        }
64    }
65
66    async fn delete(&mut self) {
67        self.valid = false;
68        let session_name = &self.session.name;
69        let request = DeleteSessionRequest {
70            name: session_name.to_string(),
71        };
72        match self.spanner_client.delete_session(request, true, None).await {
73            Ok(_) => self.deleted = true,
74            Err(e) => tracing::warn!("failed to delete session {}, {:?}", session_name, e),
75        };
76    }
77}
78
79/// ManagedSession
80pub struct ManagedSession {
81    session_pool: SessionPool,
82    session: Option<SessionHandle>,
83}
84
85impl ManagedSession {
86    fn new(session_pool: SessionPool, session: SessionHandle) -> Self {
87        ManagedSession {
88            session_pool,
89            session: Some(session),
90        }
91    }
92}
93
94impl Drop for ManagedSession {
95    fn drop(&mut self) {
96        let session = self.session.take().unwrap();
97        self.session_pool.recycle(session);
98    }
99}
100
101impl Deref for ManagedSession {
102    type Target = SessionHandle;
103
104    fn deref(&self) -> &Self::Target {
105        self.session.as_ref().unwrap()
106    }
107}
108
109impl DerefMut for ManagedSession {
110    fn deref_mut(&mut self) -> &mut Self::Target {
111        self.session.as_mut().unwrap()
112    }
113}
114
115/// Sessions have all sessions and waiters.
116/// This is for atomically locking the waiting list and free sessions.
117struct Sessions {
118    available_sessions: VecDeque<SessionHandle>,
119
120    waiters: VecDeque<oneshot::Sender<()>>,
121
122    /// Invalid sessions living in the server.
123    orphans: Vec<SessionHandle>,
124
125    /// number of sessions user uses.
126    num_inuse: usize,
127
128    /// number of sessions scheduled to be replenished.
129    num_creating: usize,
130
131    /// Maximum observed number of sessions in use during the current window.
132    max_inuse_window: usize,
133    /// Start of the rolling window used for `max_inuse_window`.
134    window_started_at: Instant,
135}
136
137impl Sessions {
138    fn num_opened(&self) -> usize {
139        self.num_inuse + self.available_sessions.len()
140    }
141
142    fn take_waiter(&mut self) -> Option<oneshot::Sender<()>> {
143        while let Some(waiter) = self.waiters.pop_front() {
144            // Waiter can be closed when session acquisition times out.
145            if !waiter.is_closed() {
146                return Some(waiter);
147            }
148        }
149        None
150    }
151
152    fn take(&mut self) -> Option<SessionHandle> {
153        match self.available_sessions.pop_front() {
154            None => None,
155            Some(s) => {
156                self.num_inuse += 1;
157                self.update_max_in_use();
158                Some(s)
159            }
160        }
161    }
162
163    fn release(&mut self, session: SessionHandle) {
164        if self.num_inuse > 0 {
165            self.num_inuse -= 1;
166        }
167        if session.valid {
168            self.available_sessions.push_back(session);
169        } else if !session.deleted {
170            tracing::trace!("save as orphan name={}", session.session.name);
171            self.orphans.push(session);
172        }
173    }
174
175    /// reserve calculates next session count to create.
176    /// Must call replenish after calling this method.
177    fn reserve(&mut self, max_opened: usize, inc_step: usize) -> usize {
178        let num_opened = self.num_opened();
179        let num_creating = self.num_creating;
180        if max_opened < num_creating + num_opened {
181            tracing::trace!(
182                "No available connections max={}, num_creating={}, current={}",
183                max_opened,
184                num_creating,
185                num_opened
186            );
187            return 0;
188        }
189        let mut increasing = max_opened - (num_creating + num_opened);
190        if increasing > inc_step {
191            increasing = inc_step
192        }
193        self.num_creating += increasing;
194        increasing
195    }
196
197    fn replenish(&mut self, session_count: usize, result: Result<Vec<SessionHandle>, Status>) {
198        self.num_creating -= session_count;
199        match result {
200            Ok(mut new_sessions) => {
201                while let Some(session) = new_sessions.pop() {
202                    self.available_sessions.push_back(session);
203                    if let Some(waiter) = self.take_waiter() {
204                        let _ = waiter.send(());
205                    }
206                }
207            }
208            Err(e) => tracing::error!("failed to create new sessions {:?}", e),
209        }
210    }
211
212    fn update_max_in_use(&mut self) {
213        let now = Instant::now();
214        if now.duration_since(self.window_started_at) >= MAX_IN_USE_WINDOW {
215            self.window_started_at = now;
216            self.max_inuse_window = self.num_inuse;
217        } else if self.num_inuse > self.max_inuse_window {
218            self.max_inuse_window = self.num_inuse;
219        }
220    }
221}
222
223#[derive(Clone)]
224struct SessionPool {
225    inner: Arc<RwLock<Sessions>>,
226    session_creation_sender: UnboundedSender<usize>,
227    config: Arc<SessionConfig>,
228    metrics: Arc<MetricsRecorder>,
229}
230
231impl SessionPool {
232    async fn new(
233        database: String,
234        conn_pool: &ConnectionManager,
235        session_creation_sender: UnboundedSender<usize>,
236        config: Arc<SessionConfig>,
237        disable_route_to_leader: bool,
238        metrics: Arc<MetricsRecorder>,
239    ) -> Result<Self, Status> {
240        let available_sessions =
241            Self::init_pool(database, conn_pool, config.min_opened, disable_route_to_leader, metrics.clone()).await?;
242        let pool = SessionPool {
243            inner: Arc::new(RwLock::new(Sessions {
244                available_sessions,
245                waiters: VecDeque::new(),
246                orphans: Vec::new(),
247                num_inuse: 0,
248                num_creating: 0,
249                max_inuse_window: 0,
250                window_started_at: Instant::now(),
251            })),
252            session_creation_sender,
253            config,
254            metrics,
255        };
256        pool.metrics.register_session_pool(pool.snapshot_fn());
257        Ok(pool)
258    }
259
260    async fn init_pool(
261        database: String,
262        conn_pool: &ConnectionManager,
263        min_opened: usize,
264        disable_route_to_leader: bool,
265        metrics: Arc<MetricsRecorder>,
266    ) -> Result<VecDeque<SessionHandle>, Status> {
267        let channel_num = conn_pool.num();
268        let creation_count_per_channel = min_opened / channel_num;
269        let remainder = min_opened % channel_num;
270
271        let mut sessions = Vec::<SessionHandle>::new();
272        let mut tasks = JoinSet::new();
273        for i in 0..channel_num {
274            // Ensure that we create the exact number of requested sessions by adding the remainder to the first channel.
275            let creation_count = if i == 0 {
276                creation_count_per_channel + remainder
277            } else {
278                creation_count_per_channel
279            };
280            let next_client = conn_pool
281                .conn()
282                .with_metrics(metrics.clone())
283                .with_metadata(client_metadata(&database));
284            let database = database.clone();
285            tasks.spawn(async move {
286                batch_create_sessions(next_client, &database, creation_count, disable_route_to_leader).await
287            });
288        }
289        while let Some(r) = tasks.join_next().await {
290            let new_sessions = r.map_err(|e| Status::from_error(e.into()))??;
291            sessions.extend(new_sessions);
292        }
293        tracing::debug!("initial session created count = {}", sessions.len());
294        Ok(sessions.into())
295    }
296
297    fn num_opened(&self) -> usize {
298        self.inner.read().num_opened()
299    }
300
301    /// The client first checks the waiting list.
302    /// If the waiting list is empty, it retrieves the first available session.
303    /// If there are no available sessions, it enters the waiting list.
304    /// If the waiting list is not empty, the client enters the waiting list.
305    /// The client on the waiting list will be notified when another client's session has finished and
306    /// when the process of replenishing the available sessions is complete.
307    async fn acquire(&self) -> Result<ManagedSession, SessionError> {
308        let request_started_at = Instant::now();
309        loop {
310            let (on_session_acquired, session_count) = {
311                let mut sessions = self.inner.write();
312
313                // Prioritize waiters over new acquirers.
314                if sessions.waiters.is_empty() {
315                    if let Some(mut s) = sessions.take() {
316                        s.last_used_at = Instant::now();
317                        self.metrics.record_session_acquired();
318                        self.metrics
319                            .record_session_acquire_latency(request_started_at.elapsed());
320                        return Ok(ManagedSession::new(self.clone(), s));
321                    }
322                }
323                // Add the participant to the waiting list.
324                let (sender, receiver) = oneshot::channel();
325                sessions.waiters.push_back(sender);
326                let session_count = sessions.reserve(self.config.max_opened, self.config.inc_step);
327                (receiver, session_count)
328            };
329
330            if session_count > 0 {
331                let _ = self.session_creation_sender.send(session_count);
332            }
333
334            // Wait for the session available notification.
335            match timeout(self.config.session_get_timeout, on_session_acquired).await {
336                Ok(Ok(())) => {
337                    let mut sessions = self.inner.write();
338                    if let Some(mut s) = sessions.take() {
339                        s.last_used_at = Instant::now();
340                        self.metrics.record_session_acquired();
341                        self.metrics
342                            .record_session_acquire_latency(request_started_at.elapsed());
343                        return Ok(ManagedSession::new(self.clone(), s));
344                    } else {
345                        continue; // another waiter raced for session
346                    }
347                }
348                _ => {
349                    {
350                        let sessions = self.inner.write();
351                        tracing::info!(
352                            available = sessions.available_sessions.len(),
353                            waiters = sessions.waiters.len(),
354                            orphans = sessions.orphans.len(),
355                            num_inuse = sessions.num_inuse,
356                            num_creating = sessions.num_creating,
357                            max_opened = self.config.max_opened,
358                            "Timeout acquiring session"
359                        );
360                    }
361                    self.metrics.record_session_timeout();
362                    return Err(SessionError::SessionGetTimeout);
363                }
364            }
365        }
366    }
367
368    /// If the session is valid
369    ///  - Pass the session to the first user on the waiting list.
370    ///  - If there is no waiting list, the session is returned to the list of available sessions.
371    ///    If the session is invalid
372    ///  - Discard the session. If the number of sessions falls below the threshold as a result of discarding, the session replenishment process is called.
373    fn recycle(&self, mut session: SessionHandle) {
374        self.metrics.record_session_released();
375        if session.valid {
376            let mut sessions = self.inner.write();
377            let waiter = sessions.take_waiter();
378            if sessions.num_opened() > self.config.max_idle
379                && session.created_at + self.config.idle_timeout < Instant::now()
380                && waiter.is_none()
381            {
382                // Not reuse expired idle session
383                session.valid = false
384            }
385            sessions.release(session);
386            if let Some(waiter) = waiter {
387                let _ = waiter.send(());
388            }
389        } else {
390            let session_count = {
391                let mut sessions = self.inner.write();
392                sessions.release(session);
393                if sessions.num_opened() < self.config.min_opened && !sessions.waiters.is_empty() {
394                    sessions.reserve(self.config.max_opened, self.config.inc_step)
395                } else {
396                    0
397                }
398            };
399            if session_count > 0 {
400                let _ = self.session_creation_sender.send(session_count);
401            }
402        }
403    }
404
405    async fn close(&self) {
406        let empty = VecDeque::new();
407        let deleting_sessions = { mem::replace(&mut self.inner.write().available_sessions, empty) };
408        for mut session in deleting_sessions {
409            session.delete().await;
410        }
411
412        self.remove_orphans().await;
413    }
414
415    fn snapshot_fn(&self) -> SessionPoolStatsFn {
416        let inner = self.inner.clone();
417        let max_allowed = self.config.max_opened;
418        Arc::new(move || {
419            let sessions = inner.read();
420            SessionPoolSnapshot {
421                open_sessions: sessions.num_opened(),
422                sessions_in_use: sessions.num_inuse,
423                idle_sessions: sessions.available_sessions.len(),
424                max_allowed_sessions: max_allowed,
425                max_in_use_last_window: sessions.max_inuse_window,
426                has_multiplexed_session: false,
427            }
428        })
429    }
430
431    async fn remove_orphans(&self) {
432        let empty = vec![];
433        let deleting_sessions = { mem::replace(&mut self.inner.write().orphans, empty) };
434        tracing::trace!("remove {} orphan sessions", deleting_sessions.len());
435        for mut session in deleting_sessions {
436            session.delete().await;
437        }
438    }
439}
440
/// Tuning parameters for the session pool.
#[derive(Clone, Debug)]
pub struct SessionConfig {
    /// max_opened is the maximum number of opened sessions allowed by the session
    /// pool. If the client tries to open a session and there are already
    /// max_opened sessions, it will wait until one becomes available or
    /// `session_get_timeout` elapses.
    pub max_opened: usize,

    /// min_opened is the number of sessions the pool creates when it starts.
    /// A session found to be broken is still evicted from the session pool, so
    /// the number of opened sessions can drop below min_opened. Below this
    /// threshold, evicting a session triggers replenishment only while callers
    /// are waiting for a session.
    pub min_opened: usize,

    /// max_idle limits idle-session expiry: a session returned to the pool is only
    /// discarded for being older than `idle_timeout` while the pool holds more than
    /// max_idle opened sessions (in use plus idle).
    pub max_idle: usize,

    /// idle_timeout is the session age after which a session returned to the pool
    /// may be discarded instead of reused. Age is measured from the session's
    /// creation. The session is only discarded while more than `max_idle` sessions
    /// are open and no caller is waiting for one.
    pub idle_timeout: Duration,

    /// session_alive_trust_duration: the health check does not ping sessions that
    /// were used, or successfully pinged, within this duration.
    pub session_alive_trust_duration: Duration,

    /// session_get_timeout is the maximum time a caller waits for a session when
    /// none is idle. After that, acquisition fails with
    /// `SessionError::SessionGetTimeout`.
    pub session_get_timeout: Duration,

    /// refresh_interval is the interval between background maintenance rounds
    /// (orphan deletion followed by a health check). The first round runs one
    /// interval after the pool starts.
    pub refresh_interval: Duration,

    /// inc_step is the maximum number of sessions requested in one batch when at
    /// least one more session is needed.
    inc_step: usize,
}

impl Default for SessionConfig {
    fn default() -> Self {
        SessionConfig {
            max_opened: 400,
            min_opened: 10,
            max_idle: 300,
            inc_step: 25,
            idle_timeout: Duration::from_secs(30 * 60),
            session_alive_trust_duration: Duration::from_secs(55 * 60),
            session_get_timeout: Duration::from_secs(1),
            refresh_interval: Duration::from_secs(5 * 60),
        }
    }
}
492
493#[derive(thiserror::Error, Debug)]
494pub enum SessionError {
495    #[error("session get time out")]
496    SessionGetTimeout,
497    #[error("failed to create session")]
498    FailedToCreateSession,
499    #[error(transparent)]
500    GRPC(#[from] Status),
501}
502
503impl TryAs<Status> for SessionError {
504    fn try_as(&self) -> Option<&Status> {
505        match self {
506            SessionError::GRPC(e) => Some(e),
507            _ => None,
508        }
509    }
510}
511
512pub(crate) struct SessionManager {
513    session_pool: SessionPool,
514    cancel: CancellationToken,
515    tasks: Mutex<Vec<JoinHandle<()>>>,
516}
517
impl SessionManager {
    /// Creates the session pool, which opens the initial `min_opened` sessions before this
    /// returns. Then starts two background tasks: the periodic health check and the task that
    /// creates sessions on request from the pool.
    ///
    /// # Errors
    /// Returns the status of a failed initial session creation.
    pub async fn new(
        database: impl Into<String>,
        conn_pool: ConnectionManager,
        config: SessionConfig,
        disable_route_to_leader: bool,
        metrics: Arc<MetricsRecorder>,
    ) -> Result<Arc<SessionManager>, Status> {
        let database = database.into();
        // The pool keeps the sender; the creation task below consumes the receiver.
        let (sender, receiver) = mpsc::unbounded_channel();
        let session_pool = SessionPool::new(
            database.clone(),
            &conn_pool,
            sender,
            Arc::new(config.clone()),
            disable_route_to_leader,
            metrics.clone(),
        )
        .await?;

        let cancel = CancellationToken::new();
        let task_session_cleaner = Self::spawn_health_check_task(config, session_pool.clone(), cancel.clone());
        let task_session_creator = Self::spawn_session_creation_task(
            session_pool.clone(),
            database,
            conn_pool,
            receiver,
            cancel.clone(),
            disable_route_to_leader,
        );

        let sm = SessionManager {
            session_pool,
            cancel,
            tasks: Mutex::new(vec![task_session_cleaner, task_session_creator]),
        };
        Ok(Arc::new(sm))
    }

    /// Number of sessions owned by the pool (in use plus idle).
    pub fn num_opened(&self) -> usize {
        self.session_pool.num_opened()
    }

    /// Checks a session out of the pool; it is returned to the pool when dropped.
    pub async fn get(&self) -> Result<ManagedSession, SessionError> {
        self.session_pool.acquire().await
    }

    /// Stops the background tasks, waits for them to finish, then deletes the pool's idle and
    /// orphaned sessions. Sessions currently checked out are not deleted here.
    /// Returns immediately if the manager has already been cancelled.
    pub async fn close(&self) {
        if self.cancel.is_cancelled() {
            return;
        }
        self.cancel.cancel();
        // Take the handles out so the lock is not held across the awaits below.
        let tasks = { mem::take(&mut *self.tasks.lock()) };
        for task in tasks {
            let _ = task.await;
        }
        self.session_pool.close().await;
    }

    /// Spawns the task that creates sessions on demand.
    ///
    /// Each `usize` received on `rx` is a count already reserved by the pool. The task creates
    /// that many sessions in a spawned sub-task and hands each result to `replenish`, which
    /// settles the reservation. `biased` polls the branches in order: cancellation first, then
    /// finished creations, then new requests. In-flight creations are dropped with the
    /// `JoinSet` on shutdown.
    fn spawn_session_creation_task(
        session_pool: SessionPool,
        database: String,
        conn_pool: ConnectionManager,
        mut rx: UnboundedReceiver<usize>,
        cancel: CancellationToken,
        disable_route_to_leader: bool,
    ) -> JoinHandle<()> {
        tokio::spawn(async move {
            let mut tasks = JoinSet::default();
            loop {
                // NOTE(review): a sub-task that fails to join (panic/abort) does not match
                // `Some(Ok(..))`, so `replenish` is never called for it and its reserved count
                // stays in `num_creating`.
                // `rx.recv()` should not yield `None` here: `session_pool` (moved into this
                // task) holds a sender, so the channel stays open.
                select! {
                    biased;
                    _ = cancel.cancelled() => break,
                    Some(Ok((session_count, result))) = tasks.join_next(), if !tasks.is_empty() => {
                        session_pool.inner.write().replenish(session_count, result);
                    }
                    session_count = rx.recv() => match session_count {
                        Some(session_count) => {
                            let client = conn_pool
                                .conn()
                                .with_metrics(session_pool.metrics.clone())
                                .with_metadata(client_metadata(&database));
                            let database = database.clone();
                            tasks.spawn(async move { (session_count, batch_create_sessions(client, &database, session_count, disable_route_to_leader).await) });
                        },
                        None => continue
                    },
                }
            }
            tracing::trace!("shutdown session creation task.");
        })
    }

    /// Spawns the periodic maintenance task. Every `refresh_interval`, starting one interval
    /// from now, it deletes orphaned sessions and then runs a health-check round.
    fn spawn_health_check_task(
        config: SessionConfig,
        session_pool: SessionPool,
        cancel: CancellationToken,
    ) -> JoinHandle<()> {
        let start = Instant::now() + config.refresh_interval;
        let mut interval = tokio::time::interval_at(start.into(), config.refresh_interval);

        tokio::spawn(async move {
            loop {
                select! {
                    _ = interval.tick() => {},
                    _ = cancel.cancelled() => break
                }
                let now = Instant::now();

                // remove orphans first
                session_pool.remove_orphans().await;

                // start health check
                // The extra nanosecond presumably keeps this round's stamp distinct from
                // timestamps set earlier (health_check compares `last_checked_at == now`).
                health_check(
                    now + Duration::from_nanos(1),
                    config.session_alive_trust_duration,
                    &session_pool,
                    cancel.clone(),
                )
                .await;
            }
            tracing::trace!("shutdown health check task.")
        })
    }
}
643
/// Runs one health-check round over the idle sessions.
///
/// Idle sessions are taken from the front of the queue one at a time, with a 10ms pause before
/// each. Handling per session:
/// - Used or successfully pinged within `session_alive_trust_duration` of `now`: stamped with
///   `now` and put back.
/// - Otherwise: pinged with `ping_query_request`. On success it is stamped and recycled; on
///   failure it is deleted and recycled, which drops it from the idle queue.
///
/// Released sessions go to the back of the queue. So meeting a session whose `last_checked_at`
/// already equals `now` means every idle session has been visited, and the round ends. The
/// round also ends when no session is idle or when `cancel` fires.
async fn health_check(
    now: Instant,
    session_alive_trust_duration: Duration,
    sessions: &SessionPool,
    cancel: CancellationToken,
) {
    tracing::trace!("start health check");
    let start = Instant::now();
    let sleep_duration = Duration::from_millis(10);
    loop {
        // Pace the round, and stop promptly on shutdown.
        select! {
            _ = sleep(sleep_duration) => {},
            _ = cancel.cancelled() => break
        }
        let mut s = {
            // temporary take
            // NOTE(review): `take` counts the session as in use (and feeds the max-in-use
            // metric) while it is being checked.
            let mut locked = sessions.inner.write();
            match locked.take() {
                Some(mut s) => {
                    // all the session check complete.
                    // (this round's stamp seen again means the queue has wrapped around)
                    if s.last_checked_at == now {
                        locked.release(s);
                        break;
                    }
                    // Recently used or recently pinged: trust it without a ping.
                    if std::cmp::max(s.last_used_at, s.last_pong_at) + session_alive_trust_duration >= now {
                        s.last_checked_at = now;
                        locked.release(s);
                        continue;
                    }
                    s
                }
                None => break,
            }
        };

        // The pool lock was released at the end of the block above; the ping runs without it.
        let request = ping_query_request(s.session.name.clone());
        match s.spanner_client.execute_sql(request, true, None).await {
            Ok(_) => {
                // Stamped with the round's `now`, not the time the ping completed.
                s.last_checked_at = now;
                s.last_pong_at = now;
                // NOTE(review): `recycle` records a release metric; no matching acquire metric
                // is recorded for the `take` above.
                sessions.recycle(s);
            }
            Err(_) => {
                // `delete` marks the session invalid, so `recycle` keeps it out of the idle
                // queue (as an orphan if the delete RPC itself failed).
                s.delete().await;
                sessions.recycle(s);
            }
        }
    }
    tracing::trace!("end health check elapsed={}msec", start.elapsed().as_millis());
}
694
695async fn batch_create_sessions(
696    spanner_client: Client,
697    database: &str,
698    mut remaining_create_count: usize,
699    disable_route_to_leader: bool,
700) -> Result<Vec<SessionHandle>, Status> {
701    let mut created = Vec::with_capacity(remaining_create_count);
702    while remaining_create_count > 0 {
703        let sessions = batch_create_session(
704            spanner_client.clone(),
705            database,
706            remaining_create_count,
707            disable_route_to_leader,
708        )
709        .await?;
710        // Spanner could return less sessions than requested.
711        // In that case, we should do another call using the same gRPC channel.
712        let actually_created = sessions.len();
713        remaining_create_count -= actually_created;
714        created.extend(sessions);
715    }
716    Ok(created)
717}
718
719async fn batch_create_session(
720    mut spanner_client: Client,
721    database: &str,
722    session_count: usize,
723    disable_route_to_leader: bool,
724) -> Result<Vec<SessionHandle>, Status> {
725    let request = BatchCreateSessionsRequest {
726        database: database.to_string(),
727        session_template: None,
728        session_count: session_count as i32,
729    };
730
731    tracing::debug!("spawn session creation request : session_count = {}", session_count);
732    let response = spanner_client
733        .batch_create_sessions(request, disable_route_to_leader, None)
734        .await?
735        .into_inner();
736
737    let now = Instant::now();
738    Ok(response
739        .session
740        .into_iter()
741        .map(|s| SessionHandle::new(s, spanner_client.clone(), now))
742        .collect::<Vec<SessionHandle>>())
743}
744
745pub(crate) fn client_metadata(database: &str) -> MetadataMap {
746    let mut metadata = MetadataMap::new();
747    metadata.insert("google-cloud-resource-prefix", database.parse().unwrap());
748    metadata
749}
750
751#[cfg(test)]
752mod tests {
753    use std::sync::atomic::{AtomicI64, Ordering};
754    use std::sync::Arc;
755    use std::time::{Duration, Instant};
756
757    use parking_lot::RwLock;
758    use serial_test::serial;
759    use tokio::time::sleep;
760    use tokio_util::sync::CancellationToken;
761
762    use google_cloud_gax::conn::{ConnectionOptions, Environment};
763    use google_cloud_googleapis::spanner::v1::ExecuteSqlRequest;
764
765    use crate::apiv1::conn_pool::ConnectionManager;
766    use crate::metrics::MetricsRecorder;
767    use crate::session::{
768        batch_create_sessions, client_metadata, health_check, SessionConfig, SessionError, SessionManager,
769    };
770
771    pub const DATABASE: &str = "projects/local-project/instances/test-instance/databases/local-database";
772
773    #[ctor::ctor]
774    fn init() {
775        let filter = tracing_subscriber::filter::EnvFilter::from_default_env()
776            .add_directive("google_cloud_spanner=trace".parse().unwrap());
777        let _ = tracing_subscriber::fmt().with_env_filter(filter).try_init();
778    }
779
780    async fn assert_rush(use_invalidate: bool, config: SessionConfig) -> Arc<SessionManager> {
781        let cm = ConnectionManager::new(
782            4,
783            &Environment::Emulator("localhost:9010".to_string()),
784            "",
785            &ConnectionOptions::default(),
786        )
787        .await
788        .unwrap();
789        let sm = SessionManager::new(DATABASE, cm, config, false, Arc::new(MetricsRecorder::default()))
790            .await
791            .unwrap();
792
793        let counter = Arc::new(AtomicI64::new(0));
794        let mut spawns = Vec::with_capacity(100);
795        for _ in 0..100 {
796            let sm = sm.clone();
797            let counter = Arc::clone(&counter);
798            spawns.push(tokio::spawn(async move {
799                let mut session = sm.get().await.unwrap();
800                if use_invalidate {
801                    session.delete().await;
802                }
803                counter.fetch_add(1, Ordering::SeqCst);
804                sleep(Duration::from_millis(300)).await;
805            }));
806        }
807        for handler in spawns {
808            let _ = handler.await;
809        }
810        sm
811    }
812
813    #[tokio::test(flavor = "multi_thread")]
814    #[serial]
815    async fn test_health_check_checked() {
816        let cm = ConnectionManager::new(
817            4,
818            &Environment::Emulator("localhost:9010".to_string()),
819            "",
820            &ConnectionOptions::default(),
821        )
822        .await
823        .unwrap();
824        let session_alive_trust_duration = Duration::from_millis(10);
825        let config = SessionConfig {
826            min_opened: 5,
827            session_alive_trust_duration,
828            max_opened: 5,
829            ..Default::default()
830        };
831        let sm = std::sync::Arc::new(
832            SessionManager::new(DATABASE, cm, config, false, Arc::new(MetricsRecorder::default()))
833                .await
834                .unwrap(),
835        );
836        sleep(Duration::from_secs(1)).await;
837
838        let cancel = CancellationToken::new();
839        health_check(Instant::now(), session_alive_trust_duration, &sm.session_pool, cancel.clone()).await;
840
841        assert_eq!(sm.num_opened(), 5);
842        tokio::time::sleep(Duration::from_millis(500)).await;
843        cancel.cancel();
844    }
845
846    #[tokio::test(flavor = "multi_thread")]
847    #[serial]
848    async fn test_health_check_not_checked() {
849        let cm = ConnectionManager::new(
850            4,
851            &Environment::Emulator("localhost:9010".to_string()),
852            "",
853            &ConnectionOptions::default(),
854        )
855        .await
856        .unwrap();
857        let session_alive_trust_duration = Duration::from_secs(10);
858        let config = SessionConfig {
859            min_opened: 5,
860            session_alive_trust_duration,
861            max_opened: 5,
862            ..Default::default()
863        };
864        let sm = Arc::new(
865            SessionManager::new(DATABASE, cm, config, false, Arc::new(MetricsRecorder::default()))
866                .await
867                .unwrap(),
868        );
869        sleep(Duration::from_secs(1)).await;
870
871        let cancel = CancellationToken::new();
872        health_check(Instant::now(), session_alive_trust_duration, &sm.session_pool, cancel.clone()).await;
873
874        assert_eq!(sm.num_opened(), 5);
875        sleep(Duration::from_millis(500)).await;
876        cancel.cancel();
877    }
878
879    #[tokio::test(flavor = "multi_thread")]
880    #[serial]
881    async fn test_increase_session_and_idle_session_expired() {
882        let conn_pool = ConnectionManager::new(
883            4,
884            &Environment::Emulator("localhost:9010".to_string()),
885            "",
886            &ConnectionOptions::default(),
887        )
888        .await
889        .unwrap();
890        let config = SessionConfig {
891            idle_timeout: Duration::from_millis(10),
892            min_opened: 10,
893            max_idle: 20,
894            max_opened: 45,
895            ..Default::default()
896        };
897        let sm = SessionManager::new(DATABASE, conn_pool, config, false, Arc::new(MetricsRecorder::default()))
898            .await
899            .unwrap();
900        {
901            let mut sessions = Vec::new();
902            for _ in 0..45 {
903                sessions.push(sm.get().await.unwrap());
904            }
905
906            // all the session are using
907            assert_eq!(sm.num_opened(), 45);
908            assert_eq!(sm.session_pool.inner.read().num_inuse, 45, "all the session are using");
909            sleep(Duration::from_secs(1)).await;
910        }
911
912        // idle session removed after drop
913        let sessions = sm.session_pool.inner.read();
914        assert_eq!(sessions.num_inuse, 0, "invalid num_inuse");
915        assert_eq!(sessions.available_sessions.len(), 20, "invalid available sessions");
916        assert_eq!(sessions.num_opened(), 20, "invalid num open");
917        assert_eq!(sessions.waiters.len(), 0, "session waiters is 0");
918    }
919
920    #[tokio::test(flavor = "multi_thread")]
921    #[serial]
922    async fn test_too_many_session_timeout() {
923        let conn_pool = ConnectionManager::new(
924            4,
925            &Environment::Emulator("localhost:9010".to_string()),
926            "",
927            &ConnectionOptions::default(),
928        )
929        .await
930        .unwrap();
931        let config = SessionConfig {
932            idle_timeout: Duration::from_millis(10),
933            min_opened: 10,
934            max_idle: 20,
935            max_opened: 45,
936            session_get_timeout: Duration::from_secs(1),
937            ..Default::default()
938        };
939        let sm = Arc::new(
940            SessionManager::new(DATABASE, conn_pool, config.clone(), false, Arc::new(MetricsRecorder::default()))
941                .await
942                .unwrap(),
943        );
944        let mu = Arc::new(RwLock::new(Vec::new()));
945        let mut awaiters = Vec::with_capacity(100);
946        for _ in 0..100 {
947            let sm = sm.clone();
948            let mu = mu.clone();
949            awaiters.push(tokio::spawn(async move {
950                let session = sm.get().await;
951                mu.write().push(session);
952                0
953            }))
954        }
955        for handler in awaiters {
956            let _ = handler.await;
957        }
958        let sessions = mu.read();
959        for i in 0..sessions.len() - 1 {
960            let session = &sessions[i];
961            if i >= config.max_opened {
962                assert!(session.is_err(), "must err {i}");
963                match session.as_ref().err().unwrap() {
964                    SessionError::SessionGetTimeout => {}
965                    _ => {
966                        panic!("must be session timeout error")
967                    }
968                }
969            } else {
970                assert!(session.is_ok(), "must ok {i}");
971            }
972        }
973        let pool = sm.session_pool.inner.read();
974        assert_eq!(pool.num_opened(), config.max_opened);
975        assert_eq!(pool.waiters.len(), 100 - config.max_opened); //include timeout sessions
976    }
977
978    #[tokio::test(flavor = "multi_thread")]
979    #[serial]
980    async fn test_rush_invalidate() {
981        let config = SessionConfig {
982            session_get_timeout: Duration::from_secs(20),
983            min_opened: 10,
984            max_idle: 20,
985            max_opened: 45,
986            ..Default::default()
987        };
988        let sm = assert_rush(true, config.clone()).await;
989        {
990            let sessions = sm.session_pool.inner.read();
991            let available_sessions = sessions.available_sessions.len();
992            assert_eq!(sessions.num_inuse, 0);
993            assert_eq!(sessions.waiters.len(), 0);
994            assert_eq!(sessions.orphans.len(), 0);
995            assert!(
996                available_sessions <= config.max_opened && available_sessions >= config.min_opened,
997                "now is {available_sessions}"
998            );
999        }
1000        sm.close().await;
1001    }
1002
1003    #[tokio::test(flavor = "multi_thread")]
1004    #[serial]
1005    async fn test_rush() {
1006        let config = SessionConfig {
1007            min_opened: 10,
1008            max_idle: 20,
1009            max_opened: 45,
1010            ..Default::default()
1011        };
1012        let sm = assert_rush(false, config.clone()).await;
1013        {
1014            let sessions = sm.session_pool.inner.read();
1015            let available_sessions = sessions.available_sessions.len();
1016            assert_eq!(sessions.num_inuse, 0);
1017            assert_eq!(sessions.waiters.len(), 0);
1018            assert_eq!(sessions.orphans.len(), 0);
1019            assert!(
1020                available_sessions <= config.max_opened && available_sessions >= config.min_opened,
1021                "now is {available_sessions}"
1022            );
1023        }
1024        sm.close().await;
1025    }
1026
1027    #[tokio::test(flavor = "multi_thread")]
1028    #[serial]
1029    async fn test_rush_with_invalidate() {
1030        let config = SessionConfig {
1031            min_opened: 10,
1032            max_idle: 20,
1033            max_opened: 45,
1034            ..Default::default()
1035        };
1036        let sm = assert_rush(true, config.clone()).await;
1037        {
1038            let sessions = sm.session_pool.inner.read();
1039            let available_sessions = sessions.available_sessions.len();
1040            assert_eq!(sessions.num_inuse, 0);
1041            assert_eq!(sessions.waiters.len(), 0);
1042            assert_eq!(sessions.orphans.len(), 0);
1043            assert!(
1044                available_sessions <= config.max_opened && available_sessions >= config.min_opened,
1045                "now is {available_sessions}"
1046            );
1047        }
1048        sm.close().await;
1049    }
1050
1051    #[tokio::test(flavor = "multi_thread")]
1052    #[serial]
1053    async fn test_rush_with_health_check() {
1054        let config = SessionConfig {
1055            session_alive_trust_duration: Duration::from_millis(10),
1056            refresh_interval: Duration::from_millis(250),
1057            session_get_timeout: Duration::from_secs(20),
1058            min_opened: 10,
1059            max_idle: 20,
1060            max_opened: 45,
1061            ..Default::default()
1062        };
1063        let sm = assert_rush(false, config.clone()).await;
1064        sleep(Duration::from_secs(2)).await;
1065        {
1066            let sessions = sm.session_pool.inner.read();
1067            let available_sessions = sessions.available_sessions.len();
1068            assert!(sessions.num_inuse <= 1, "num_inuse is {}", sessions.num_inuse);
1069            assert_eq!(sessions.waiters.len(), 0);
1070            assert_eq!(sessions.orphans.len(), 0);
1071            assert!(
1072                available_sessions <= config.max_opened && available_sessions >= config.max_idle - 1,
1073                "now is {available_sessions}"
1074            );
1075        }
1076        sm.close().await;
1077    }
1078
1079    #[tokio::test(flavor = "multi_thread")]
1080    #[serial]
1081    async fn test_rush_with_health_check_and_invalidate() {
1082        let config = SessionConfig {
1083            session_alive_trust_duration: Duration::from_millis(10),
1084            refresh_interval: Duration::from_millis(250),
1085            session_get_timeout: Duration::from_secs(20),
1086            min_opened: 10,
1087            max_idle: 20,
1088            max_opened: 45,
1089            ..Default::default()
1090        };
1091        let sm = assert_rush(true, config.clone()).await;
1092        sleep(Duration::from_secs(2)).await;
1093        {
1094            let sessions = sm.session_pool.inner.read();
1095            let available_sessions = sessions.available_sessions.len();
1096            assert!(sessions.num_inuse <= 1, "num_inuse is {}", sessions.num_inuse);
1097            assert_eq!(sessions.waiters.len(), 0);
1098            assert_eq!(sessions.orphans.len(), 0);
1099            assert!(
1100                available_sessions <= config.max_opened && available_sessions >= config.min_opened - 1,
1101                "now is {available_sessions}"
1102            );
1103        }
1104        sm.close().await;
1105    }
1106
1107    #[tokio::test(flavor = "multi_thread")]
1108    #[serial]
1109    async fn test_rush_with_idle_expired() {
1110        let config = SessionConfig {
1111            min_opened: 10,
1112            max_idle: 20,
1113            max_opened: 45,
1114            idle_timeout: Duration::from_millis(1),
1115            ..Default::default()
1116        };
1117        let sm = assert_rush(false, config.clone()).await;
1118        {
1119            let sessions = sm.session_pool.inner.read();
1120            assert_eq!(sessions.num_inuse, 0);
1121            assert_eq!(sessions.waiters.len(), 0);
1122            assert_eq!(sessions.orphans.len(), config.max_opened - config.max_idle);
1123            assert_eq!(sessions.available_sessions.len(), config.max_idle);
1124        }
1125        sm.close().await;
1126    }
1127
1128    #[tokio::test(flavor = "multi_thread")]
1129    #[serial]
1130    async fn test_rush_with_health_check_and_idle_expired() {
1131        let config = SessionConfig {
1132            session_alive_trust_duration: Duration::from_millis(10),
1133            refresh_interval: Duration::from_millis(250),
1134            session_get_timeout: Duration::from_secs(20),
1135            min_opened: 10,
1136            max_idle: 20,
1137            max_opened: 45,
1138            idle_timeout: Duration::from_millis(1),
1139            ..Default::default()
1140        };
1141        let sm = assert_rush(false, config.clone()).await;
1142        sleep(Duration::from_secs(1)).await;
1143        {
1144            let sessions = sm.session_pool.inner.read();
1145            assert!(sessions.num_inuse <= 1, "num_inuse is {}", sessions.num_inuse);
1146            assert_eq!(sessions.waiters.len(), 0);
1147            assert_eq!(sessions.orphans.len(), 0);
1148            let available_sessions = sessions.available_sessions.len();
1149            assert!(
1150                available_sessions >= config.min_opened - 1 && available_sessions <= config.max_idle,
1151                "now is {available_sessions}"
1152            );
1153        }
1154        sm.close().await;
1155    }
1156
1157    #[tokio::test(flavor = "multi_thread")]
1158    #[serial]
1159    async fn test_rush_with_health_check_and_idle_expired_and_invalid() {
1160        let config = SessionConfig {
1161            session_alive_trust_duration: Duration::from_millis(10),
1162            refresh_interval: Duration::from_millis(250),
1163            session_get_timeout: Duration::from_secs(20),
1164            min_opened: 10,
1165            max_idle: 20,
1166            max_opened: 45,
1167            idle_timeout: Duration::from_millis(1),
1168            ..Default::default()
1169        };
1170        let sm = assert_rush(true, config.clone()).await;
1171        sleep(Duration::from_secs(2)).await;
1172        {
1173            let sessions = sm.session_pool.inner.read();
1174            assert!(sessions.num_inuse <= 1, "num_inuse is {}", sessions.num_inuse);
1175            // health checker removes orphans
1176            assert_eq!(sessions.orphans.len(), 0);
1177            assert_eq!(sessions.waiters.len(), 0, "invalid waiters");
1178            let available_sessions = sessions.available_sessions.len();
1179            assert!(
1180                available_sessions >= config.min_opened - 1 && available_sessions <= config.max_idle,
1181                "now is {available_sessions}"
1182            );
1183        }
1184        sm.close().await;
1185    }
1186
1187    #[tokio::test(flavor = "multi_thread")]
1188    #[serial]
1189    async fn test_close() {
1190        let cm = ConnectionManager::new(
1191            4,
1192            &Environment::Emulator("localhost:9010".to_string()),
1193            "",
1194            &ConnectionOptions::default(),
1195        )
1196        .await
1197        .unwrap();
1198        let config = SessionConfig::default();
1199        let sm = SessionManager::new(DATABASE, cm, config.clone(), false, Arc::new(MetricsRecorder::default()))
1200            .await
1201            .unwrap();
1202        assert_eq!(sm.num_opened(), config.min_opened);
1203        sm.close().await;
1204        assert_eq!(sm.num_opened(), 0);
1205        assert_eq!(sm.session_pool.inner.read().orphans.len(), 0);
1206    }
1207
1208    #[tokio::test(flavor = "multi_thread")]
1209    #[serial]
1210    async fn test_batch_create_sessions() {
1211        let cm = ConnectionManager::new(
1212            1,
1213            &Environment::Emulator("localhost:9010".to_string()),
1214            "",
1215            &ConnectionOptions::default(),
1216        )
1217        .await
1218        .unwrap();
1219        let client = cm
1220            .conn()
1221            .with_metrics(Arc::new(MetricsRecorder::default()))
1222            .with_metadata(client_metadata(DATABASE));
1223        let session_count = 125;
1224        let result = batch_create_sessions(client.clone(), DATABASE, session_count, false).await;
1225        match result {
1226            Ok(created) => {
1227                assert_eq!(session_count, created.len());
1228                for mut s in created {
1229                    let ping_result = s
1230                        .spanner_client
1231                        .execute_sql(
1232                            ExecuteSqlRequest {
1233                                session: s.session.name.to_string(),
1234                                transaction: None,
1235                                sql: "SELECT 1".to_string(),
1236                                params: None,
1237                                param_types: Default::default(),
1238                                resume_token: vec![],
1239                                query_mode: 0,
1240                                partition_token: vec![],
1241                                seqno: 0,
1242                                query_options: None,
1243                                request_options: None,
1244                                directed_read_options: None,
1245                                data_boost_enabled: false,
1246                                last_statement: false,
1247                            },
1248                            false,
1249                            None,
1250                        )
1251                        .await;
1252                    assert!(ping_result.is_ok());
1253                }
1254            }
1255            Err(err) => panic!("{err:?}"),
1256        }
1257    }
1258}