1use super::*;
20use std::{vec, time::Duration};
21
22use deadpool::managed::{Manager, Pool, PoolBuilder, PoolConfig, Hook};
23
24use tonic::transport::{Endpoint, Uri};
25use tower::ServiceExt;
26
27use payload::YdbResponseWithResult;
28use generated::ydb::discovery::{EndpointInfo, ListEndpointsRequest};
29use auth::Credentials;
30use crate::client::YdbEndpoint;
31
32
33type YdbEndpoints = std::sync::RwLock<Vec<YdbEndpoint>>;
34pub type YdbPool<C> = Pool<ConnectionManager<C>>;
35
36impl From<EndpointInfo> for YdbEndpoint {
37 fn from(value: EndpointInfo) -> Self {
38 Self {
39 ssl: value.ssl,
40 host: value.address,
41 port: value.port as u16,
42 load_factor: value.load_factor,
43 }
44 }
45}
46
47
48fn make_endpoint(info: &YdbEndpoint) -> Endpoint {
49 let uri: tonic::transport::Uri = format!("{}://{}:{}", info.scheme(), info.host, info.port).try_into().unwrap();
50 let mut e = Endpoint::from(uri).tcp_keepalive(Some(std::time::Duration::from_secs(15)));
51 if info.ssl {
52 e = e.tls_config(Default::default()).unwrap()
53 }
54 e
55}
56
57pub trait GetScheme {
58 fn get_scheme(&self) -> &'static str;
59}
60
61impl GetScheme for EndpointInfo {
62 fn get_scheme(&self) -> &'static str {
63 if self.ssl { "grpcs" } else { "grpc" }
64 }
65}
66
67pub struct ConnectionManager<C> {
68 creds: C,
69 db_name: AsciiValue,
70 endpoints: YdbEndpoints,
71}
72
73impl<C: Credentials> ConnectionManager<C> {
74 pub fn next_endpoint(&self) -> Endpoint {
75 let endpoints = self.endpoints.read().unwrap();
76 if endpoints.len() == 1 {
77 return endpoints.first().unwrap().make_endpoint();
78 } else if endpoints.is_empty() {
79 panic!("List of endpoints is empty");
80 }
81 let mut rng = rand::thread_rng();
82 use rand::Rng;
83 let e1 = rng.gen::<usize>() % endpoints.len();
84 let mut e2 = e1;
85 while e2 == e1 { e2 = rng.gen::<usize>() % endpoints.len();
87 }
88 let e1 = &endpoints[e1];
89 let e2 = &endpoints[e2];
90 let endpoint = if e1.load_factor < e2.load_factor {e1} else { e2 };
91 make_endpoint(&endpoint)
92 }
93}
94
95#[async_trait::async_trait]
96impl <C: Credentials + Sync> Manager for ConnectionManager<C> {
97 type Type = YdbConnection<C>;
98
99 type Error = tonic::transport::Error;
100
101 async fn create(&self) -> Result<Self::Type, Self::Error> {
102 let endpoint = self.next_endpoint();
103 let channel = endpoint.connect().await?;
104 let db_name = self.db_name.clone();
105 let creds = self.creds.clone();
106 Ok(YdbConnection::new(channel, db_name, creds))
107 }
108
109 async fn recycle(&self, obj: &mut Self::Type) -> deadpool::managed::RecycleResult<Self::Error> {
110 obj.ready().await?;
111 Ok(())
112 }
113}
114
115pub struct YdbPoolBuilder<C: Credentials + Send + Sync> {
117 inner: PoolBuilder<ConnectionManager<C>>,
118 update_interval: Duration,
119}
120
/// Generates builder-style methods that forward their single argument to the
/// method of the same name on the wrapped `inner` builder, keeping every other
/// field of `Self` unchanged.
macro_rules! delegate {
    ($( $method:ident($arg_ty:ty), )+) => {
        $(
            #[doc = concat!("Forwards to the wrapped builder's `", stringify!($method), "`.")]
            pub fn $method(self, value: $arg_ty) -> Self {
                Self {
                    inner: self.inner.$method(value),
                    ..self
                }
            }
        )+
    };
}
129impl<C: Credentials + Send + Sync> YdbPoolBuilder<C> {
131 pub fn new(creds: C, db_name: AsciiValue, endpoint: YdbEndpoint) -> Self {
132 let endpoints = std::sync::RwLock::new(vec![endpoint]);
133 let inner = Pool::builder(ConnectionManager {creds, db_name, endpoints});
134 let update_interval = Duration::from_secs(77);
135 Self {inner, update_interval}
136 }
137 pub fn update_interval(mut self, interval: Duration) -> Self {
139 self.update_interval = interval;
140 self
141 }
142 delegate!{
143 config(PoolConfig),
144 create_timeout(Option<Duration>),
145 max_size(usize),
146 post_create(impl Into<Hook<ConnectionManager<C>>>),
147 post_recycle(impl Into<Hook<ConnectionManager<C>>>),
148 pre_recycle(impl Into<Hook<ConnectionManager<C>>>),
149 recycle_timeout(Option<Duration>),
150 runtime(deadpool::Runtime),
151 timeouts(deadpool::managed::Timeouts),
152 wait_timeout(Option<Duration>),
153 }
154 pub fn build(self) -> Result<Pool<ConnectionManager<C>>, deadpool::managed::BuildError<tonic::transport::Error>> {
155 let pool = self.inner.build()?;
156 let result = pool.clone();
157 let db_name = pool.manager().db_name.to_str().unwrap().to_owned();
158 tokio::spawn(async move {
159 loop {
160 if pool.is_closed() {
161 log::debug!("Connection pool closed");
162 break;
163 }
164 if let Err(e) = update_endpoints(&pool, db_name.clone()).await {
165 log::error!("Error on update endpoints for pool: {e:?}");
166 }
167 tokio::time::sleep(self.update_interval).await;
168 }
169 });
170 Ok(result)
171 }
172}
173
174async fn update_endpoints<C: Credentials + Send + Sync>(pool: &Pool<ConnectionManager<C>>, database: String) -> Result<(), Box<dyn std::error::Error>> {
175 let mut service = pool.get().await?;
176 let mut discovery = service.discovery();
177 let response = discovery.list_endpoints(ListEndpointsRequest{database, ..Default::default()}).await?;
178 let endpoints: Vec<_> = response.into_inner().result()?.endpoints.into_iter().map(From::from).collect();
179 log::debug!("Pool endpoints updated ({} endpoints)", endpoints.len());
180 *pool.manager().endpoints.write().unwrap() = endpoints;
181 Ok(())
182}
183pub fn to_endpoint_info(value: Uri) -> Result<EndpointInfo, String> {
184 let mut e = EndpointInfo::default();
185 e.ssl = match value.scheme_str() {
186 Some("grpc") => false,
187 Some("grpcs") => true,
188 _ => return Err("Unknown protocol".to_owned()),
189 };
190 e.address = value.host().ok_or("no host")?.to_owned();
191 e.port = value.port_u16().ok_or("no port")? as u32;
192 Ok(e)
193}