// hstreamdb 0.2.2
//
// Rust client library for HStreamDB
// Documentation
use std::io::Write;

use flate2::write::GzDecoder;
use hstreamdb_pb::h_stream_api_client::HStreamApiClient;
use hstreamdb_pb::{
    BatchHStreamRecords, HStreamRecord, LookupShardRequest, ReceivedRecord, RecordId, Shard,
};
use md5::{Digest, Md5};
use num_bigint::BigInt;
use num_traits::Num;
use prost::Message;
use tonic::transport::Channel;

use crate::common::{self, PartitionKey, ShardId};
use crate::{format_url, Error};

/// Resolves the URL of the server node responsible for `shard_id`.
///
/// When `shard_url` is already known, an owned copy of it is returned and no
/// request is made. Otherwise a `LookupShardRequest` is sent over `channel`
/// and the URL is assembled from `url_scheme` plus the host and port of the
/// server node found in the response.
///
/// # Errors
///
/// Propagates a failure of the lookup call, and returns
/// `Error::PBUnwrapError` when the response carries no `server_node`.
pub async fn lookup_shard(
    channel: &mut HStreamApiClient<Channel>,
    url_scheme: &str,
    shard_id: ShardId,
    shard_url: Option<&String>,
) -> common::Result<String> {
    // Fast path: the caller already knows where the shard lives.
    if let Some(known_url) = shard_url {
        return Ok(known_url.to_string());
    }

    let request = LookupShardRequest { shard_id };
    let response = channel.lookup_shard(request).await?.into_inner();
    let node = response
        .server_node
        .ok_or_else(|| Error::PBUnwrapError("server_node".to_string()))?;

    Ok(format_url!(url_scheme, node.host, node.port))
}

pub fn partition_key_to_shard_id(
    shards: &[Shard],
    partition_key: PartitionKey,
) -> common::Result<ShardId> {
    let hash = {
        let mut hasher = Md5::new();
        hasher.update(partition_key);
        let hash = hasher.finalize();
        let hash = format!("{hash:x}");
        BigInt::from_str_radix(&hash, 16).map_err(|err| {
            common::Error::PartitionKeyError(common::PartitionKeyError::ParseBigIntError(err))
        })?
    };

    for shard in shards {
        let start = BigInt::from_str_radix(&shard.start_hash_range_key, 10).map_err(|err| {
            common::Error::PartitionKeyError(common::PartitionKeyError::ParseBigIntError(err))
        })?;
        let end = BigInt::from_str_radix(&shard.end_hash_range_key, 10).map_err(|err| {
            common::Error::PartitionKeyError(common::PartitionKeyError::ParseBigIntError(err))
        })?;

        if start <= hash && hash <= end {
            return Ok(shard.shard_id);
        }
    }

    Err(common::Error::PartitionKeyError(
        common::PartitionKeyError::NoMatch,
    ))
}

pub(crate) fn decode_received_records(
    received_records: ReceivedRecord,
) -> common::Result<Vec<(RecordId, HStreamRecord)>> {
    let is_empty = received_records.record_ids.is_empty();
    let record_ids = received_records.record_ids;
    let received_records = received_records.record;
    if is_empty {
        Ok(Vec::new())
    } else {
        match received_records {
            None => Err(common::Error::PBUnwrapError(
                "received_records.record".to_string(),
            )),
            Some(record) => {
                let compression_type = record.compression_type();
                let payload = {
                    let payload = record.payload;
                    match compression_type {
                        hstreamdb_pb::CompressionType::None => Ok(payload),
                        hstreamdb_pb::CompressionType::Gzip => {
                            let mut decoder = GzDecoder::new(Vec::new());
                            decoder
                                .write_all(&payload)
                                .map_err(common::Error::DecompressError)?;
                            decoder.finish().map_err(common::Error::DecompressError)
                        }
                        hstreamdb_pb::CompressionType::Zstd => zstd::decode_all(payload.as_slice())
                            .map_err(common::Error::DecompressError),
                    }
                }?;

                let records = BatchHStreamRecords::decode(payload.as_slice())
                    .map_err(common::Error::PBDecodeError)?
                    .records;
                Ok(record_ids.into_iter().zip(records).collect())
            }
        }
    }
}

/// Builds a `scheme://host:port` URL string.
///
/// Two forms are accepted:
///
/// * `format_url!(scheme, host, port)` formats the three values directly.
/// * `format_url!(scheme, server_node)` reads the `host` and `port` fields of
///   `server_node`. The `server_node` expression is expanded once per field,
///   so pass a place expression (a variable or field), not a call with side
///   effects.
///
/// The two-argument arm recurses through `$crate::` so the macro still
/// resolves when it is invoked by path from another crate that has not
/// imported it.
#[doc(hidden)]
#[macro_export]
macro_rules! format_url {
    ($scheme:expr, $host:expr, $port:expr) => {
        format!("{}://{}:{}", $scheme, $host, $port)
    };
    ($scheme:expr, $server_node:expr) => {
        $crate::format_url!($scheme, $server_node.host, $server_node.port)
    };
}

#[cfg(test)]
mod tests {
    use std::collections::HashSet;
    use std::env;

    use hstreamdb_pb::{ListShardsRequest, Stream};
    use hstreamdb_test_utils::rand_alphanumeric;

    use super::partition_key_to_shard_id;
    use crate::client::Client;
    use crate::ChannelProviderSettings;

    /// Integration test: requires a server reachable at `TEST_SERVER_ADDR`.
    ///
    /// Creates a stream with 200 shards, maps 400 random partition keys to
    /// shards, and asserts that more than 100 distinct shards were hit.
    #[tokio::test(flavor = "multi_thread")]
    async fn test_partition_key_to_shard_id() {
        let addr = env::var("TEST_SERVER_ADDR").unwrap();
        let settings = ChannelProviderSettings::builder().build();
        let client = Client::new(addr, settings).await.unwrap();

        let stream_name = rand_alphanumeric(20);
        let stream = Stream {
            stream_name: stream_name.clone(),
            replication_factor: 1,
            backlog_duration: 10 * 60,
            shard_count: 200,
            creation_time: None,
        };
        client.create_stream(stream).await.unwrap();

        let shards = client
            .channels
            .channel()
            .await
            .list_shards(ListShardsRequest { stream_name })
            .await
            .unwrap()
            .into_inner()
            .shards;

        // Only the number of distinct shard ids matters, so a set is enough.
        let distinct_shards: HashSet<_> = (0..400)
            .map(|_| partition_key_to_shard_id(&shards, rand_alphanumeric(20)).unwrap())
            .collect();

        println!("result.len() = {}", distinct_shards.len());
        assert!(distinct_shards.len() > 100)
    }
}