ydb_unofficial/
pool.rs

//! Implementation of a pool of [`YdbConnection`]s.
//! Uses the `list_endpoints` method of `DiscoveryServiceClient` to spread the pool over multiple endpoints.
3//! # Examples
4//! ```rust
5//! # #[tokio::main]
6//! # async fn main() {
7//! let db_name = std::env::var("DB_NAME").expect("DB_NAME not set");
8//! let creds = std::env::var("DB_TOKEN").expect("DB_TOKEN not set");
9//! let ep = ydb_unofficial::client::YdbEndpoint {ssl: true, host: "ydb.serverless.yandexcloud.net".to_owned(), port: 2135, load_factor: 0.0};
10//! let pool = ydb_unofficial::pool::YdbPoolBuilder::new(creds, db_name.try_into().unwrap(), ep).build().unwrap();
11//! let mut conn = pool.get().await.unwrap();
12//! let mut table_client = conn.table();
13//! //do something...
14//! let mut conn2 = pool.get().await.unwrap();
//! //do something else...
16//! pool.close();
17//! # }
18//! ```
19use super::*;
20use std::{vec, time::Duration};
21
22use deadpool::managed::{Manager, Pool, PoolBuilder, PoolConfig, Hook};
23
24use tonic::transport::{Endpoint, Uri};
25use tower::ServiceExt;
26
27use payload::YdbResponseWithResult;
28use generated::ydb::discovery::{EndpointInfo, ListEndpointsRequest};
29use auth::Credentials;
30use crate::client::YdbEndpoint;
31
32
33type YdbEndpoints = std::sync::RwLock<Vec<YdbEndpoint>>;
34pub type YdbPool<C> = Pool<ConnectionManager<C>>;
35
36impl From<EndpointInfo> for YdbEndpoint {
37    fn from(value: EndpointInfo) -> Self {
38        Self {
39            ssl: value.ssl,
40            host: value.address,
41            port: value.port as u16,
42            load_factor: value.load_factor,
43        }
44    }
45}
46
47
48fn make_endpoint(info: &YdbEndpoint) -> Endpoint {
49    let uri: tonic::transport::Uri = format!("{}://{}:{}", info.scheme(), info.host, info.port).try_into().unwrap();
50    let mut e = Endpoint::from(uri).tcp_keepalive(Some(std::time::Duration::from_secs(15)));
51    if info.ssl {
52        e = e.tls_config(Default::default()).unwrap()
53    }
54    e
55}
56
57pub trait GetScheme {
58    fn get_scheme(&self) -> &'static str;
59}
60
61impl GetScheme for EndpointInfo {
62    fn get_scheme(&self) -> &'static str {
63        if self.ssl { "grpcs" } else { "grpc" }
64    }
65}
66
67pub struct ConnectionManager<C> {
68    creds: C,
69    db_name: AsciiValue,
70    endpoints: YdbEndpoints,
71}
72
73impl<C: Credentials> ConnectionManager<C> {
74    pub fn next_endpoint(&self) -> Endpoint {
75        let endpoints = self.endpoints.read().unwrap();
76        if endpoints.len() == 1 {
77            return endpoints.first().unwrap().make_endpoint();
78        } else if endpoints.is_empty() {
79            panic!("List of endpoints is empty");
80        }
81        let mut rng = rand::thread_rng();
82        use rand::Rng;
83        let e1 = rng.gen::<usize>() % endpoints.len();
84        let mut e2 = e1;
85        while e2 == e1 { //TODO: кажется, это что-то неоптимальное
86            e2 = rng.gen::<usize>() % endpoints.len();
87        }
88        let e1 = &endpoints[e1];
89        let e2 = &endpoints[e2];
90        let endpoint = if e1.load_factor < e2.load_factor {e1} else { e2 };
91        make_endpoint(&endpoint)
92    }
93}
94
95#[async_trait::async_trait]
96impl <C: Credentials + Sync> Manager for ConnectionManager<C> {
97    type Type = YdbConnection<C>;
98
99    type Error = tonic::transport::Error;
100
101    async fn create(&self) ->  Result<Self::Type, Self::Error> {
102        let endpoint = self.next_endpoint();
103        let channel = endpoint.connect().await?;
104        let db_name = self.db_name.clone();
105        let creds = self.creds.clone();
106        Ok(YdbConnection::new(channel, db_name, creds))
107    }
108
109    async fn recycle(&self, obj: &mut Self::Type) ->  deadpool::managed::RecycleResult<Self::Error> {
110        obj.ready().await?;
111        Ok(())
112    }
113}
114
115/// Builder for pool of [`YdbConnection`]
116pub struct YdbPoolBuilder<C: Credentials + Send + Sync> {
117    inner: PoolBuilder<ConnectionManager<C>>,
118    update_interval: Duration,
119}
120
/// Generates builder-style setters that forward to the same-named method of `self.inner`.
///
/// Each `name(Type),` entry expands to `pub fn name(mut self, v: Type) -> Self`.
/// That method replaces `self.inner` with `self.inner.name(v)` and returns `self`.
/// Every entry, including the last, must end with a comma.
macro_rules! delegate {
    ($( $method:ident($arg_ty:ty), )+) => {
        $(
            pub fn $method(mut self, v: $arg_ty) -> Self {
                self.inner = self.inner.$method(v);
                self
            }
        )+
    };
}
129/// Wrapper on [`PoolBuilder`] for YdbConnection.
130impl<C: Credentials + Send + Sync> YdbPoolBuilder<C> {
131    pub fn new(creds: C, db_name: AsciiValue, endpoint: YdbEndpoint) -> Self {
132        let endpoints =  std::sync::RwLock::new(vec![endpoint]);
133        let inner = Pool::builder(ConnectionManager {creds, db_name, endpoints});
134        let update_interval = Duration::from_secs(77);
135        Self {inner, update_interval}
136    }
137    /// Set period to update endpoints for pool. Default is 77 seconds.
138    pub fn update_interval(mut self, interval: Duration) -> Self {
139        self.update_interval = interval;
140        self
141    }
142    delegate!{ 
143        config(PoolConfig),
144        create_timeout(Option<Duration>),
145        max_size(usize),
146        post_create(impl Into<Hook<ConnectionManager<C>>>),
147        post_recycle(impl Into<Hook<ConnectionManager<C>>>),
148        pre_recycle(impl Into<Hook<ConnectionManager<C>>>),
149        recycle_timeout(Option<Duration>),
150        runtime(deadpool::Runtime),
151        timeouts(deadpool::managed::Timeouts),
152        wait_timeout(Option<Duration>),
153    }
154    pub fn build(self) -> Result<Pool<ConnectionManager<C>>, deadpool::managed::BuildError<tonic::transport::Error>> {
155        let pool = self.inner.build()?;
156        let result = pool.clone();
157        let db_name = pool.manager().db_name.to_str().unwrap().to_owned();
158        tokio::spawn(async move {
159            loop {
160                if pool.is_closed() {
161                    log::debug!("Connection pool closed");
162                    break;
163                }
164                if let Err(e) = update_endpoints(&pool, db_name.clone()).await {
165                    log::error!("Error on update endpoints for pool: {e:?}");
166                }
167                tokio::time::sleep(self.update_interval).await;
168            }
169        });
170        Ok(result)
171    }
172}
173
174async fn update_endpoints<C: Credentials + Send + Sync>(pool: &Pool<ConnectionManager<C>>, database: String) -> Result<(), Box<dyn std::error::Error>> {
175    let mut service = pool.get().await?;
176    let mut discovery = service.discovery();
177    let response = discovery.list_endpoints(ListEndpointsRequest{database, ..Default::default()}).await?; 
178    let endpoints: Vec<_> = response.into_inner().result()?.endpoints.into_iter().map(From::from).collect();
179    log::debug!("Pool endpoints updated ({} endpoints)", endpoints.len());
180    *pool.manager().endpoints.write().unwrap() = endpoints;
181    Ok(())
182}
183pub fn to_endpoint_info(value: Uri) -> Result<EndpointInfo, String> {
184    let mut e = EndpointInfo::default();
185    e.ssl = match value.scheme_str() {
186        Some("grpc") => false,
187        Some("grpcs") => true,
188        _ => return Err("Unknown protocol".to_owned()),
189    };
190    e.address = value.host().ok_or("no host")?.to_owned();
191    e.port = value.port_u16().ok_or("no port")? as u32;
192    Ok(e)
193}