1use std::net::SocketAddr;
28use std::time::{Duration, Instant};
29
30use rand::{Rng, SeedableRng};
31
32use crate::error::{ClientError, ProtocolError, Result};
33use crate::internal::{BufferPool, Driver, OwnedBuf};
34use crate::protocol::{
35 Account, AccountBalance, AccountFilter, Command, CreateAccountsResult, CreateTransfersResult,
36 Header, Message, Operation, QueryFilter, RegisterRequest, RegisterResult, RequestBuilder,
37 Transfer, HEADER_SIZE, MESSAGE_SIZE_MAX,
38};
39
/// Release number this client stamps on every request it builds
/// (the value handed to `RequestBuilder::release`).
const CLIENT_RELEASE: u32 = 1;
42
/// Lifecycle of a client session.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum State {
    /// Freshly built; `register` has not been started yet.
    Disconnected,
    /// `register` has sent its request and is waiting for the reply.
    Registering,
    /// Registration completed; `request` accepts operations.
    Ready,
    /// `close` has been called.
    Shutdown,
}
51
52pub struct Client {
84 id: u128,
86 cluster: u128,
88 replica_count: u8,
90 driver: Driver,
92 state: State,
94 view: u32,
96 session: u64,
98 request_number: u32,
100 parent: u128,
102 batch_size_limit: Option<u32>,
104 rng: rand::rngs::StdRng,
106 send_buffer: Vec<u8>,
108 buffer_pool: BufferPool,
110 request_timeout: Duration,
112 request_timeout_max: Duration,
114}
115
116impl Client {
117 pub async fn connect(cluster: u128, addresses: &str) -> Result<Self> {
133 Self::builder()
134 .cluster(cluster)
135 .addresses(addresses)?
136 .build()
137 .await
138 }
139
140 pub fn builder() -> ClientBuilder {
154 ClientBuilder::new()
155 }
156
157 pub fn id(&self) -> u128 {
159 self.id
160 }
161
162 pub fn cluster(&self) -> u128 {
164 self.cluster
165 }
166
167 pub fn is_ready(&self) -> bool {
169 self.state == State::Ready
170 }
171
172 pub fn batch_size_limit(&self) -> Option<u32> {
174 self.batch_size_limit
175 }
176
177 pub fn max_batch_count<T>(&self) -> Option<u32> {
188 let limit = self.batch_size_limit?;
189 let element_size = std::mem::size_of::<T>() as u32;
190 if element_size == 0 {
191 return None;
192 }
193 let trailer_size = crate::protocol::multi_batch::trailer_total_size(element_size, 1);
195 let max_payload = limit.saturating_sub(trailer_size);
196 Some(max_payload / element_size)
197 }
198
199 pub async fn create_accounts(
220 &mut self,
221 accounts: &[Account],
222 ) -> Result<Vec<CreateAccountsResult>> {
223 let response = self.request(Operation::CreateAccounts, accounts).await?;
224 let payload = crate::protocol::multi_batch::decode(
225 &response,
226 std::mem::size_of::<CreateAccountsResult>() as u32,
227 );
228 Ok(parse_results(payload))
229 }
230
231 pub async fn create_transfers(
236 &mut self,
237 transfers: &[Transfer],
238 ) -> Result<Vec<CreateTransfersResult>> {
239 let response = self.request(Operation::CreateTransfers, transfers).await?;
240 let payload = crate::protocol::multi_batch::decode(
241 &response,
242 std::mem::size_of::<CreateTransfersResult>() as u32,
243 );
244 Ok(parse_results(payload))
245 }
246
247 pub async fn lookup_accounts(&mut self, ids: &[u128]) -> Result<Vec<Account>> {
249 let response = self.request(Operation::LookupAccounts, ids).await?;
250 let payload =
251 crate::protocol::multi_batch::decode(&response, std::mem::size_of::<Account>() as u32);
252 Ok(parse_results(payload))
253 }
254
255 pub async fn lookup_transfers(&mut self, ids: &[u128]) -> Result<Vec<Transfer>> {
257 let response = self.request(Operation::LookupTransfers, ids).await?;
258 let payload =
259 crate::protocol::multi_batch::decode(&response, std::mem::size_of::<Transfer>() as u32);
260 Ok(parse_results(payload))
261 }
262
263 pub async fn get_account_transfers(&mut self, filter: AccountFilter) -> Result<Vec<Transfer>> {
265 let response = self
266 .request(Operation::GetAccountTransfers, &[filter])
267 .await?;
268 let payload =
269 crate::protocol::multi_batch::decode(&response, std::mem::size_of::<Transfer>() as u32);
270 Ok(parse_results(payload))
271 }
272
273 pub async fn get_account_balances(
275 &mut self,
276 filter: AccountFilter,
277 ) -> Result<Vec<AccountBalance>> {
278 let response = self
279 .request(Operation::GetAccountBalances, &[filter])
280 .await?;
281 let payload = crate::protocol::multi_batch::decode(
282 &response,
283 std::mem::size_of::<AccountBalance>() as u32,
284 );
285 Ok(parse_results(payload))
286 }
287
288 pub async fn query_accounts(&mut self, filter: QueryFilter) -> Result<Vec<Account>> {
290 let response = self.request(Operation::QueryAccounts, &[filter]).await?;
291 let payload =
292 crate::protocol::multi_batch::decode(&response, std::mem::size_of::<Account>() as u32);
293 Ok(parse_results(payload))
294 }
295
296 pub async fn query_transfers(&mut self, filter: QueryFilter) -> Result<Vec<Transfer>> {
298 let response = self.request(Operation::QueryTransfers, &[filter]).await?;
299 let payload =
300 crate::protocol::multi_batch::decode(&response, std::mem::size_of::<Transfer>() as u32);
301 Ok(parse_results(payload))
302 }
303
304 pub async fn close(mut self) {
306 self.state = State::Shutdown;
307 self.driver.close().await;
308 self.buffer_pool.clear_quarantine();
309 }
310
311 async fn register(&mut self) -> Result<()> {
317 if self.state != State::Disconnected {
318 return Err(ClientError::InvalidOperation);
319 }
320
321 self.state = State::Registering;
322
323 let body = RegisterRequest::default();
325 let body_bytes = unsafe {
326 std::slice::from_raw_parts(
327 &body as *const _ as *const u8,
328 std::mem::size_of::<RegisterRequest>(),
329 )
330 };
331
332 let msg = RequestBuilder::new(self.cluster, self.id)
333 .session(0)
334 .request(0)
335 .parent(0)
336 .operation(Operation::Register)
337 .release(CLIENT_RELEASE)
338 .body(body_bytes)
339 .build();
340
341 self.parent = msg.header().checksum;
342
343 let reply = self.send_request_with_retry(msg).await?;
345
346 let body = reply.body();
348 if body.len() < std::mem::size_of::<RegisterResult>() {
349 return Err(ClientError::Protocol(ProtocolError::InvalidSize));
350 }
351
352 let result: &RegisterResult = unsafe { &*(body.as_ptr() as *const RegisterResult) };
353
354 self.batch_size_limit = Some(result.batch_size_limit);
356 self.session = reply.header().as_reply().commit;
357 self.parent = reply.header().as_reply().context;
358 self.request_number = 1;
359 self.state = State::Ready;
360
361 Ok(())
362 }
363
364 async fn request<E: Copy>(&mut self, operation: Operation, events: &[E]) -> Result<Vec<u8>> {
366 if self.state != State::Ready {
367 return Err(ClientError::NotRegistered);
368 }
369
370 let events_bytes = unsafe {
372 std::slice::from_raw_parts(
373 events.as_ptr() as *const u8,
374 std::mem::size_of_val(events),
375 )
376 };
377
378 let body_slice: &[u8] = if operation.is_multi_batch() {
380 let element_size = std::mem::size_of::<E>() as u32;
381 let trailer_size = crate::protocol::multi_batch::trailer_total_size(element_size, 1);
382 let total_size = events_bytes.len() as u32 + trailer_size;
383
384 if let Some(limit) = self.batch_size_limit {
386 if total_size > limit {
387 return Err(ClientError::RequestTooLarge {
388 size: total_size,
389 limit,
390 });
391 }
392 }
393 let encoded_size = crate::protocol::multi_batch::encode(
394 &mut self.send_buffer[..total_size as usize],
395 events_bytes,
396 element_size,
397 );
398 &self.send_buffer[..encoded_size as usize]
399 } else {
400 events_bytes
401 };
402
403 let msg = RequestBuilder::new(self.cluster, self.id)
405 .session(self.session)
406 .request(self.request_number)
407 .parent(self.parent)
408 .operation(operation)
409 .release(CLIENT_RELEASE)
410 .view(self.view)
411 .body(body_slice)
412 .build();
413
414 self.parent = msg.header().checksum;
415 self.request_number += 1;
416
417 let reply = self.send_request_with_retry(msg).await?;
419
420 let reply_header = reply.header().as_reply();
422 self.parent = reply_header.context;
423
424 if reply.header().view > self.view {
425 self.view = reply.header().view;
426 }
427
428 Ok(reply.body().to_vec())
429 }
430
431 async fn send_request_with_retry(&mut self, msg: Message) -> Result<Message> {
433 let mut timeout = self.request_timeout;
434 let expected_checksum = msg.header().checksum;
435
436 loop {
437 self.send_with_hedging(&msg).await?;
439
440 match self.wait_for_reply(expected_checksum, timeout).await {
442 Ok(reply) => return Ok(reply),
443 Err(ClientError::Timeout) => {
444 timeout = std::cmp::min(timeout * 2, self.request_timeout_max);
446 let jitter = self.rng.random_range(0..timeout.as_millis() as u64 / 4);
447 timeout += Duration::from_millis(jitter);
448 }
449 Err(e) => return Err(e),
450 }
451 }
452 }
453
454 async fn send_with_hedging(&mut self, msg: &Message) -> Result<()> {
456 let primary = (self.view % self.replica_count as u32) as usize;
457
458 self.ensure_connected(primary).await?;
460 self.driver.send(primary, msg.as_bytes()).await?;
461
462 if self.replica_count > 1 {
464 let backup_offset = self.rng.random_range(1..self.replica_count as usize);
465 let backup = (primary + backup_offset) % self.replica_count as usize;
466
467 if self.ensure_connected(backup).await.is_ok() {
468 let _ = self.driver.send(backup, msg.as_bytes()).await;
469 }
470 }
471
472 Ok(())
473 }
474
475 async fn ensure_connected(&mut self, idx: usize) -> Result<()> {
477 if !self.driver.is_connected(idx) {
478 self.driver.connect(idx).await?;
479 }
480 Ok(())
481 }
482
483 async fn wait_for_reply(
485 &mut self,
486 expected_checksum: u128,
487 timeout: Duration,
488 ) -> Result<Message> {
489 let start = Instant::now();
490 let primary = (self.view % self.replica_count as u32) as usize;
491
492 loop {
493 if start.elapsed() >= timeout {
494 return Err(ClientError::Timeout);
495 }
496
497 let buf = self
499 .buffer_pool
500 .acquire()
501 .ok_or(ClientError::Connection("buffer pool exhausted".into()))?;
502
503 let buf = match self.driver.recv(primary, buf).await {
505 Ok(b) => b,
506 Err(e) => {
507 self.driver.disconnect(primary).await;
509 return Err(e);
510 }
511 };
512
513 match self.try_parse_reply(&buf, expected_checksum) {
515 Ok(msg) => {
516 self.buffer_pool.release(buf);
517 return Ok(msg);
518 }
519 Err(ParseError::NeedMoreData) => {
520 self.buffer_pool.release(buf);
522 continue;
523 }
524 Err(ParseError::WrongReply) => {
525 self.buffer_pool.release(buf);
526 continue;
527 }
528 Err(ParseError::Evicted(reason)) => {
529 self.buffer_pool.release(buf);
530 return Err(ClientError::Evicted(reason));
531 }
532 Err(ParseError::Protocol(e)) => {
533 self.buffer_pool.release(buf);
534 self.driver.disconnect(primary).await;
535 return Err(ClientError::Protocol(e));
536 }
537 }
538 }
539 }
540
541 fn try_parse_reply(
543 &self,
544 buf: &OwnedBuf,
545 expected_checksum: u128,
546 ) -> std::result::Result<Message, ParseError> {
547 let data = buf.as_slice();
548
549 if data.len() < HEADER_SIZE as usize {
550 return Err(ParseError::NeedMoreData);
551 }
552
553 let header_bytes: &[u8; HEADER_SIZE as usize] = data[..HEADER_SIZE as usize]
554 .try_into()
555 .map_err(|_| ParseError::Protocol(ProtocolError::InvalidHeader))?;
556 let header = Header::from_bytes(header_bytes);
557
558 if !header.valid_checksum() {
559 return Err(ParseError::Protocol(ProtocolError::InvalidHeaderChecksum));
560 }
561
562 if header.command != Command::Reply as u8 {
563 if header.command == Command::Eviction as u8 {
564 let reason = header.as_eviction().reason;
565 return Err(ParseError::Evicted(
566 reason
567 .try_into()
568 .unwrap_or(crate::protocol::header::EvictionReason::NoSession),
569 ));
570 }
571 return Err(ParseError::Protocol(ProtocolError::UnexpectedReply));
572 }
573
574 let total_size = header.size as usize;
575 if data.len() < total_size {
576 return Err(ParseError::NeedMoreData);
577 }
578
579 let reply_header = header.as_reply();
580 if reply_header.request_checksum != expected_checksum {
581 return Err(ParseError::WrongReply);
582 }
583 if reply_header.client != self.id {
584 return Err(ParseError::WrongReply);
585 }
586
587 let body_data = &data[HEADER_SIZE as usize..total_size];
588 if !header.valid_checksum_body(body_data) {
589 return Err(ParseError::Protocol(ProtocolError::InvalidBodyChecksum));
590 }
591
592 let msg_data = data[..total_size].to_vec();
593 let msg = Message::from_bytes(msg_data)
594 .ok_or(ParseError::Protocol(ProtocolError::InvalidHeader))?;
595
596 Ok(msg)
597 }
598}
599
600enum ParseError {
602 NeedMoreData,
603 WrongReply,
604 Evicted(crate::protocol::header::EvictionReason),
605 Protocol(ProtocolError),
606}
607
/// Reinterprets `data` as a packed sequence of `R` values in native byte
/// order and returns them as a `Vec`.
///
/// Trailing bytes that do not form a complete `R` are ignored. A zero-sized
/// `R` yields an empty vector instead of dividing by zero.
///
/// `data` does not need to be aligned for `R`: each element is copied out
/// with an unaligned read, so arbitrary offsets into a network buffer are
/// fine.
///
/// NOTE(review): callers must only instantiate `R` with types for which every
/// bit pattern is valid (plain wire structs and integers); that cannot be
/// checked here.
fn parse_results<R: Copy>(data: &[u8]) -> Vec<R> {
    let element_size = std::mem::size_of::<R>();
    if element_size == 0 {
        return Vec::new();
    }

    data.chunks_exact(element_size)
        .map(|chunk| {
            // SAFETY: `chunks_exact` guarantees `chunk` is exactly
            // `size_of::<R>()` readable bytes, and `read_unaligned` places no
            // alignment requirement on the source pointer.
            unsafe { std::ptr::read_unaligned(chunk.as_ptr().cast::<R>()) }
        })
        .collect()
}
617
/// Collects the settings needed to create a `Client`.
pub struct ClientBuilder {
    /// Cluster identifier; defaults to 0.
    cluster: u128,
    /// Replica socket addresses; must be non-empty when `build` is called.
    addresses: Vec<SocketAddr>,
    /// Timeout handed to the driver for establishing connections.
    connect_timeout: Duration,
    /// Timeout for the first attempt of each request.
    request_timeout: Duration,
    /// Upper bound applied when the request timeout is doubled on retry.
    request_timeout_max: Duration,
}
641
642impl ClientBuilder {
643 pub fn new() -> Self {
645 Self {
646 cluster: 0,
647 addresses: Vec::new(),
648 connect_timeout: Duration::from_secs(5),
649 request_timeout: Duration::from_millis(500),
650 request_timeout_max: Duration::from_secs(30),
651 }
652 }
653
654 pub fn cluster(mut self, id: u128) -> Self {
656 self.cluster = id;
657 self
658 }
659
660 pub fn addresses(mut self, addrs: &str) -> Result<Self> {
662 if addrs.trim().is_empty() {
663 return Err(ClientError::Connection("no addresses provided".into()));
664 }
665
666 self.addresses = addrs
667 .split(',')
668 .map(|s| {
669 s.trim().parse().map_err(|e| {
670 ClientError::Connection(format!("invalid address '{}': {}", s.trim(), e))
671 })
672 })
673 .collect::<Result<Vec<_>>>()?;
674
675 Ok(self)
676 }
677
678 pub fn addresses_vec(mut self, addrs: Vec<SocketAddr>) -> Self {
680 self.addresses = addrs;
681 self
682 }
683
684 pub fn connect_timeout(mut self, timeout: Duration) -> Self {
686 self.connect_timeout = timeout;
687 self
688 }
689
690 pub fn request_timeout(mut self, timeout: Duration) -> Self {
692 self.request_timeout = timeout;
693 self
694 }
695
696 pub fn request_timeout_max(mut self, timeout: Duration) -> Self {
698 self.request_timeout_max = timeout;
699 self
700 }
701
702 pub async fn build(self) -> Result<Client> {
706 if self.addresses.is_empty() {
707 return Err(ClientError::Connection("no addresses provided".into()));
708 }
709
710 let id: u128 = rand::random();
711 if id == 0 {
712 return Err(ClientError::Protocol(ProtocolError::InvalidHeader));
713 }
714
715 let replica_count = self.addresses.len() as u8;
716 let driver = Driver::new(self.addresses, self.connect_timeout);
717
718 let buffer_count = replica_count as usize + 2;
719 let buffer_pool = BufferPool::new(buffer_count, MESSAGE_SIZE_MAX as usize);
720
721 let mut client = Client {
722 id,
723 cluster: self.cluster,
724 replica_count,
725 driver,
726 state: State::Disconnected,
727 view: 0,
728 session: 0,
729 request_number: 0,
730 parent: 0,
731 batch_size_limit: None,
732 rng: rand::rngs::StdRng::from_os_rng(),
733 send_buffer: vec![0u8; MESSAGE_SIZE_MAX as usize],
734 buffer_pool,
735 request_timeout: self.request_timeout,
736 request_timeout_max: self.request_timeout_max,
737 };
738
739 client.register().await?;
741
742 Ok(client)
743 }
744}
745
746impl Default for ClientBuilder {
747 fn default() -> Self {
748 Self::new()
749 }
750}
751
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh builder starts with cluster 0, no addresses and a 5 s connect
    /// timeout.
    #[test]
    fn test_builder_defaults() {
        let ClientBuilder {
            cluster,
            addresses,
            connect_timeout,
            ..
        } = ClientBuilder::new();

        assert_eq!(cluster, 0);
        assert!(addresses.is_empty());
        assert_eq!(connect_timeout, Duration::from_secs(5));
    }

    /// An empty address string is rejected.
    #[test]
    fn test_builder_addresses_empty() {
        assert!(ClientBuilder::new().addresses("").is_err());
    }

    /// A string that is not a socket address is rejected.
    #[test]
    fn test_builder_addresses_invalid() {
        assert!(ClientBuilder::new().addresses("not-an-address").is_err());
    }

    /// Two comma-separated socket addresses are both parsed.
    #[test]
    fn test_builder_addresses_valid() {
        let parsed = ClientBuilder::new()
            .addresses("127.0.0.1:3000,127.0.0.1:3001")
            .unwrap()
            .addresses;
        assert_eq!(parsed.len(), 2);
    }

    /// No input bytes produce no results.
    #[test]
    fn test_parse_results_empty() {
        let no_bytes: &[u8] = &[];
        let parsed = parse_results::<u32>(no_bytes);
        assert!(parsed.is_empty());
    }

    /// Eight bytes decode into two `u32` values (byte layout as written
    /// assumes a little-endian target).
    #[test]
    fn test_parse_results() {
        let raw: [u8; 8] = [1, 0, 0, 0, 2, 0, 0, 0];
        let parsed = parse_results::<u32>(&raw);
        assert_eq!(parsed, vec![1, 2]);
    }
}