1use crate::config::{GrpcAuth, GrpcStreamConfig, MetadataEntry, RpcKind};
4use async_trait::async_trait;
5use base64::Engine as _;
6use faucet_core::{AuthSpec, Credential, FaucetError, SharedAuthProvider, StreamPage};
7use futures_core::Stream;
8use prost::Message;
9use prost::bytes::Bytes;
10use prost_reflect::{DescriptorPool, DynamicMessage, MessageDescriptor, SerializeOptions};
11use serde_json::Value;
12use std::pin::Pin;
13use std::time::Duration;
14use tonic::codec::{Codec, DecodeBuf, Decoder, EncodeBuf, Encoder};
15use tonic::transport::Channel;
16
/// A faucet source that calls a gRPC method described by a protobuf descriptor set and turns the
/// JSON form of its response message(s) into records.
pub struct GrpcStream {
    /// User configuration: endpoint, service/method, request template, paging and reconnect
    /// policy.
    config: GrpcStreamConfig,
    /// Descriptors decoded from `config.descriptor_set_path`; used to encode requests and decode
    /// responses without generated code.
    pool: DescriptorPool,
    /// Optional injected auth provider. When present, its credential is used instead of
    /// `config.auth`.
    auth_provider: Option<SharedAuthProvider>,
}
26
27impl GrpcStream {
28 pub fn new(config: GrpcStreamConfig) -> Result<Self, FaucetError> {
30 if config.reconnect_initial_backoff.is_zero() {
33 return Err(FaucetError::Config(
34 "grpc reconnect_initial_backoff must be > 0 (a zero backoff busy-spins reconnects)"
35 .into(),
36 ));
37 }
38 let descriptor_bytes = std::fs::read(&config.descriptor_set_path).map_err(|e| {
39 FaucetError::Config(format!(
40 "failed to read descriptor set at {}: {e}",
41 config.descriptor_set_path.display()
42 ))
43 })?;
44
45 let pool = DescriptorPool::decode(Bytes::from(descriptor_bytes))
46 .map_err(|e| FaucetError::Config(format!("failed to parse FileDescriptorSet: {e}")))?;
47
48 Ok(Self {
49 config,
50 pool,
51 auth_provider: None,
52 })
53 }
54
55 pub fn with_auth_provider(mut self, provider: SharedAuthProvider) -> Self {
62 self.auth_provider = Some(provider);
63 self
64 }
65
66 pub async fn fetch_all(&self) -> Result<Vec<Value>, FaucetError> {
68 self.fetch_resolved(
69 &self.config.endpoint,
70 &self.config.service_name,
71 &self.config.method_name,
72 &self.config.request,
73 )
74 .await
75 }
76
77 async fn fetch_resolved(
79 &self,
80 endpoint: &str,
81 service_name: &str,
82 method_name: &str,
83 request: &Value,
84 ) -> Result<Vec<Value>, FaucetError> {
85 let (output_desc, request_bytes) = self.prepare_call(service_name, method_name, request)?;
86 let path = parse_method_path(service_name, method_name)?;
87
88 match self.config.rpc_kind {
89 RpcKind::Unary => {
90 let channel = self.connect_channel(endpoint).await?;
91 let mut grpc_client = self.configure_client(tonic::client::Grpc::new(channel));
92 grpc_client
93 .ready()
94 .await
95 .map_err(|e| FaucetError::Config(format!("gRPC channel not ready: {e}")))?;
96
97 let codec = DynamicCodec::new(output_desc);
98 let request = self.build_grpc_request(request_bytes).await?;
99
100 let response: tonic::Response<DynamicMessage> = grpc_client
101 .unary(request, path, codec)
102 .await
103 .map_err(|e| FaucetError::Source(format!("gRPC unary call failed: {e}")))?;
104
105 let resp_msg = response.into_inner();
106 let records =
107 serialize_and_extract(&resp_msg, self.config.records_path.as_deref())?;
108 tracing::info!(records = records.len(), "gRPC unary fetch complete");
109 Ok(records)
110 }
111 RpcKind::ServerStreaming => {
112 self.fetch_server_streaming_collect(endpoint, path, output_desc, request_bytes)
113 .await
114 }
115 }
116 }
117
118 async fn fetch_server_streaming_collect(
124 &self,
125 endpoint: &str,
126 path: tonic::codegen::http::uri::PathAndQuery,
127 output_desc: MessageDescriptor,
128 request_bytes: Vec<u8>,
129 ) -> Result<Vec<Value>, FaucetError> {
130 let max_messages = self.config.max_messages.unwrap_or(usize::MAX);
131 let mut all: Vec<Value> = Vec::new();
132 let mut messages_seen: usize = 0;
133 let mut attempt: u32 = 0;
134 let mut backoff = self.config.reconnect_initial_backoff;
135 let max_backoff = self.config.reconnect_max_backoff;
136
137 loop {
138 match self
139 .drive_server_streaming_once(
140 endpoint,
141 path.clone(),
142 output_desc.clone(),
143 request_bytes.clone(),
144 max_messages,
145 messages_seen,
146 |records| {
147 all.extend(records);
148 Ok(())
149 },
150 )
151 .await
152 {
153 Ok(consumed) => {
154 tracing::info!(
155 records = all.len(),
156 messages = messages_seen + consumed,
157 "gRPC server-streaming fetch complete"
158 );
159 return Ok(all);
160 }
161 Err(StreamOutcome::Done(consumed)) => {
162 messages_seen += consumed;
163 tracing::info!(
164 records = all.len(),
165 messages = messages_seen,
166 "gRPC server-streaming fetch complete (max_messages reached)"
167 );
168 return Ok(all);
169 }
170 Err(StreamOutcome::Transient { consumed, error }) => {
171 messages_seen += consumed;
172 if self.config.terminate_on_error {
173 return Err(error);
174 }
175 if consumed > 0 {
183 attempt = 0;
184 backoff = self.config.reconnect_initial_backoff;
185 }
186 if let Some(max_attempts) = self.config.reconnect_max_attempts
187 && attempt >= max_attempts
188 {
189 return Err(FaucetError::Source(format!(
190 "gRPC server-streaming exceeded reconnect_max_attempts={max_attempts}: {error}"
191 )));
192 }
193 attempt += 1;
194 tracing::warn!(
195 attempt,
196 backoff_ms = backoff.as_millis() as u64,
197 error = %error,
198 "gRPC server-streaming transient error, reconnecting"
199 );
200 tokio::time::sleep(backoff).await;
201 backoff = next_backoff(backoff, max_backoff);
202 }
203 }
204 }
205 }
206
207 #[allow(clippy::too_many_arguments)]
215 async fn drive_server_streaming_once<F>(
216 &self,
217 endpoint: &str,
218 path: tonic::codegen::http::uri::PathAndQuery,
219 output_desc: MessageDescriptor,
220 request_bytes: Vec<u8>,
221 max_messages: usize,
222 already_seen: usize,
223 mut on_records: F,
224 ) -> Result<usize, StreamOutcome>
225 where
226 F: FnMut(Vec<Value>) -> Result<(), FaucetError>,
227 {
228 let channel = match self.connect_channel(endpoint).await {
229 Ok(c) => c,
230 Err(e) => {
231 return Err(StreamOutcome::Transient {
232 consumed: 0,
233 error: e,
234 });
235 }
236 };
237
238 let mut grpc_client = self.configure_client(tonic::client::Grpc::new(channel));
239 if let Err(e) = grpc_client.ready().await {
240 return Err(StreamOutcome::Transient {
241 consumed: 0,
242 error: FaucetError::Source(format!("gRPC channel not ready: {e}")),
243 });
244 }
245
246 let codec = DynamicCodec::new(output_desc);
247 let request = match self.build_grpc_request(request_bytes).await {
248 Ok(r) => r,
249 Err(e) => {
250 return Err(StreamOutcome::Transient {
252 consumed: 0,
253 error: e,
254 });
255 }
256 };
257
258 let response = match grpc_client.server_streaming(request, path, codec).await {
259 Ok(r) => r,
260 Err(status) => {
261 return Err(StreamOutcome::Transient {
262 consumed: 0,
263 error: FaucetError::Source(format!(
264 "gRPC server-streaming start failed: {status}"
265 )),
266 });
267 }
268 };
269
270 let mut streaming = response.into_inner();
271 let records_path = self.config.records_path.as_deref();
272 let skip = if self.config.reconnect_replay_from_start {
277 already_seen
278 } else {
279 0
280 };
281 let mut position: usize = 0; let mut emitted: usize = 0; loop {
285 if already_seen + emitted >= max_messages {
286 return Err(StreamOutcome::Done(emitted));
287 }
288 match streaming.message().await {
289 Ok(Some(msg)) => {
290 position += 1;
291 if position <= skip {
292 continue;
294 }
295 let records = match serialize_and_extract(&msg, records_path) {
296 Ok(r) => r,
297 Err(e) => {
298 return Err(StreamOutcome::Transient {
299 consumed: emitted,
300 error: e,
301 });
302 }
303 };
304 if let Err(e) = on_records(records) {
305 return Err(StreamOutcome::Transient {
306 consumed: emitted,
307 error: e,
308 });
309 }
310 emitted += 1;
311 }
312 Ok(None) => {
313 return Ok(emitted);
314 }
315 Err(status) => {
316 return Err(StreamOutcome::Transient {
317 consumed: emitted,
318 error: FaucetError::Source(format!(
319 "gRPC server-streaming recv failed: {status}"
320 )),
321 });
322 }
323 }
324 }
325 }
326
327 fn prepare_call(
329 &self,
330 service_name: &str,
331 method_name: &str,
332 request: &Value,
333 ) -> Result<(MessageDescriptor, Vec<u8>), FaucetError> {
334 let service = self.pool.get_service_by_name(service_name).ok_or_else(|| {
335 FaucetError::Config(format!(
336 "service '{service_name}' not found in descriptor set",
337 ))
338 })?;
339
340 let method = service
341 .methods()
342 .find(|m| m.name() == method_name)
343 .ok_or_else(|| {
344 FaucetError::Config(format!(
345 "method '{method_name}' not found in service '{service_name}'",
346 ))
347 })?;
348
349 let input_desc = method.input();
350 let request_msg = DynamicMessage::deserialize(input_desc, request)
351 .map_err(|e| FaucetError::Config(format!("failed to build request message: {e}")))?;
352 let request_bytes = request_msg.encode_to_vec();
353
354 Ok((method.output(), request_bytes))
355 }
356
357 async fn connect_channel(&self, endpoint: &str) -> Result<Channel, FaucetError> {
359 let use_tls = self
360 .config
361 .tls
362 .unwrap_or_else(|| endpoint.starts_with("https"));
363
364 let channel_endpoint = Channel::from_shared(endpoint.to_string())
365 .map_err(|e| FaucetError::Url(format!("invalid gRPC endpoint: {e}")))?;
366
367 let channel = if use_tls {
368 channel_endpoint
369 .tls_config(tonic::transport::ClientTlsConfig::new())
370 .map_err(|e| FaucetError::Config(format!("TLS config failed: {e}")))?
371 .connect()
372 .await
373 .map_err(|e| FaucetError::Source(format!("gRPC connect failed: {e}")))?
374 } else {
375 channel_endpoint
376 .connect()
377 .await
378 .map_err(|e| FaucetError::Source(format!("gRPC connect failed: {e}")))?
379 };
380
381 Ok(channel)
382 }
383
384 fn configure_client(
387 &self,
388 mut client: tonic::client::Grpc<Channel>,
389 ) -> tonic::client::Grpc<Channel> {
390 if let Some(n) = self.config.max_decoding_message_size {
391 client = client.max_decoding_message_size(n);
392 }
393 if let Some(n) = self.config.max_encoding_message_size {
394 client = client.max_encoding_message_size(n);
395 }
396 client
397 }
398
399 async fn build_grpc_request(
406 &self,
407 request_bytes: Vec<u8>,
408 ) -> Result<tonic::Request<Vec<u8>>, FaucetError> {
409 let effective = if let Some(provider) = &self.auth_provider {
410 credential_to_auth(provider.credential().await?)
411 } else {
412 match &self.config.auth {
413 AuthSpec::Inline(a) => a.clone(),
414 AuthSpec::Reference(r) => {
415 return Err(FaucetError::Auth(format!(
416 "auth references provider '{}' but no provider was supplied; \
417 set one via the CLI `auth:` catalog or `with_auth_provider`",
418 r.name
419 )));
420 }
421 }
422 };
423
424 let mut request = tonic::Request::new(request_bytes);
425 apply_grpc_auth(&effective, &mut request)?;
426 Ok(request)
427 }
428}
429
430#[async_trait]
431impl faucet_core::Source for GrpcStream {
432 async fn fetch_with_context(
433 &self,
434 context: &std::collections::HashMap<String, serde_json::Value>,
435 ) -> Result<Vec<Value>, FaucetError> {
436 if context.is_empty() {
437 return GrpcStream::fetch_all(self).await;
438 }
439
440 let endpoint = faucet_core::util::substitute_context(&self.config.endpoint, context);
441 let service_name =
442 faucet_core::util::substitute_context(&self.config.service_name, context);
443 let method_name = faucet_core::util::substitute_context(&self.config.method_name, context);
444
445 let request = {
446 let s = serde_json::to_string(&self.config.request)
447 .map_err(|e| FaucetError::Config(format!("failed to serialize request: {e}")))?;
448 let s = faucet_core::util::substitute_context_json(&s, context);
449 serde_json::from_str(&s).map_err(|e| {
450 FaucetError::Config(format!("failed to parse substituted request: {e}"))
451 })?
452 };
453
454 self.fetch_resolved(&endpoint, &service_name, &method_name, &request)
455 .await
456 }
457
458 fn stream_pages<'a>(
476 &'a self,
477 context: &'a std::collections::HashMap<String, Value>,
478 batch_size: usize,
479 ) -> Pin<Box<dyn Stream<Item = Result<StreamPage, FaucetError>> + Send + 'a>> {
480 match self.config.rpc_kind {
481 RpcKind::Unary => {
482 self.default_stream_pages(context, batch_size)
485 }
486 RpcKind::ServerStreaming => self.server_streaming_pages(context),
487 }
488 }
489
490 fn config_schema(&self) -> serde_json::Value {
491 serde_json::to_value(faucet_core::schema_for!(GrpcStreamConfig))
492 .expect("schema serialization")
493 }
494}
495
496impl GrpcStream {
497 fn default_stream_pages<'a>(
501 &'a self,
502 context: &'a std::collections::HashMap<String, Value>,
503 batch_size: usize,
504 ) -> Pin<Box<dyn Stream<Item = Result<StreamPage, FaucetError>> + Send + 'a>> {
505 use faucet_core::Source;
506 Box::pin(async_stream::try_stream! {
507 let (records, bookmark) = Source::fetch_with_context_incremental(self, context).await?;
508 let total = records.len();
509 let chunk = if batch_size == 0 { usize::MAX } else { batch_size };
510
511 if total == 0 {
512 if bookmark.is_some() {
513 yield StreamPage { records: Vec::new(), bookmark };
514 }
515 return;
516 }
517
518 let mut iter = records.into_iter();
519 let mut consumed = 0usize;
520 loop {
521 let batch: Vec<Value> = iter.by_ref().take(chunk).collect();
522 if batch.is_empty() {
523 break;
524 }
525 consumed += batch.len();
526 let page_bookmark = if consumed >= total { bookmark.clone() } else { None };
527 yield StreamPage { records: batch, bookmark: page_bookmark };
528 }
529 })
530 }
531
532 fn server_streaming_pages<'a>(
536 &'a self,
537 context: &'a std::collections::HashMap<String, Value>,
538 ) -> Pin<Box<dyn Stream<Item = Result<StreamPage, FaucetError>> + Send + 'a>> {
539 let batch_size = self.config.batch_size;
540 let page_chunk = if batch_size == 0 {
541 usize::MAX
542 } else {
543 batch_size
544 };
545 let initial_capacity = if batch_size == 0 { 1024 } else { batch_size };
546 let max_messages = self.config.max_messages.unwrap_or(usize::MAX);
547 let terminate_on_error = self.config.terminate_on_error;
548 let reconnect_max_attempts = self.config.reconnect_max_attempts;
549 let reconnect_initial_backoff = self.config.reconnect_initial_backoff;
550 let mut backoff = reconnect_initial_backoff;
551 let max_backoff = self.config.reconnect_max_backoff;
552
553 Box::pin(async_stream::try_stream! {
554 let endpoint = if context.is_empty() {
558 self.config.endpoint.clone()
559 } else {
560 faucet_core::util::substitute_context(&self.config.endpoint, context)
561 };
562 let service_name = if context.is_empty() {
563 self.config.service_name.clone()
564 } else {
565 faucet_core::util::substitute_context(&self.config.service_name, context)
566 };
567 let method_name = if context.is_empty() {
568 self.config.method_name.clone()
569 } else {
570 faucet_core::util::substitute_context(&self.config.method_name, context)
571 };
572 let request: Value = if context.is_empty() {
573 self.config.request.clone()
574 } else {
575 let s = serde_json::to_string(&self.config.request)
576 .map_err(|e| FaucetError::Config(format!("failed to serialize request: {e}")))?;
577 let s = faucet_core::util::substitute_context_json(&s, context);
578 serde_json::from_str(&s).map_err(|e| FaucetError::Config(format!(
579 "failed to parse substituted request: {e}"
580 )))?
581 };
582
583 let (output_desc, request_bytes) =
584 self.prepare_call(&service_name, &method_name, &request)?;
585 let path = parse_method_path(&service_name, &method_name)?;
586
587 let mut buffer: Vec<Value> = Vec::with_capacity(initial_capacity);
588 let mut messages_seen: usize = 0;
589 let mut attempt: u32 = 0;
590
591 'reconnect: loop {
592 let outcome = self.drive_server_streaming_once(
596 &endpoint,
597 path.clone(),
598 output_desc.clone(),
599 request_bytes.clone(),
600 max_messages,
601 messages_seen,
602 |records| {
603 buffer.extend(records);
604 Ok(())
605 },
606 ).await;
607
608 while buffer.len() >= page_chunk {
612 let drained: Vec<Value> = buffer.drain(..page_chunk).collect();
613 yield StreamPage { records: drained, bookmark: None };
614 }
615
616 match outcome {
617 Ok(consumed) => {
618 messages_seen += consumed;
619 if !buffer.is_empty() {
622 let final_records = std::mem::take(&mut buffer);
623 yield StreamPage { records: final_records, bookmark: None };
624 }
625 tracing::info!(
626 messages = messages_seen,
627 "gRPC server-streaming complete"
628 );
629 break 'reconnect;
630 }
631 Err(StreamOutcome::Done(consumed)) => {
632 messages_seen += consumed;
633 if !buffer.is_empty() {
634 let final_records = std::mem::take(&mut buffer);
635 yield StreamPage { records: final_records, bookmark: None };
636 }
637 tracing::info!(
638 messages = messages_seen,
639 "gRPC server-streaming complete (max_messages reached)"
640 );
641 break 'reconnect;
642 }
643 Err(StreamOutcome::Transient { consumed, error }) => {
644 messages_seen += consumed;
645 if terminate_on_error {
646 Err(error)?;
652 return;
653 }
654 if consumed > 0 {
659 attempt = 0;
660 backoff = reconnect_initial_backoff;
661 }
662 if let Some(max_attempts) = reconnect_max_attempts
663 && attempt >= max_attempts
664 {
665 let final_err = FaucetError::Source(format!(
666 "gRPC server-streaming exceeded reconnect_max_attempts={max_attempts}: {error}"
667 ));
668 Err(final_err)?;
669 return;
670 }
671 attempt += 1;
672 tracing::warn!(
673 attempt,
674 backoff_ms = backoff.as_millis() as u64,
675 error = %error,
676 "gRPC server-streaming transient error, reconnecting"
677 );
678 tokio::time::sleep(backoff).await;
679 backoff = next_backoff(backoff, max_backoff);
680 }
681 }
682 }
683 })
684 }
685}
686
687fn credential_to_auth(cred: Credential) -> GrpcAuth {
695 match cred {
696 Credential::Bearer(token) => GrpcAuth::Bearer { token },
697 Credential::Token(token) => GrpcAuth::Metadata {
698 entries: vec![MetadataEntry {
699 key: "authorization".into(),
700 value: token,
701 }],
702 },
703 Credential::Header { name, value } => GrpcAuth::Metadata {
704 entries: vec![MetadataEntry { key: name, value }],
705 },
706 Credential::Basic { username, password } => GrpcAuth::Metadata {
707 entries: vec![MetadataEntry {
708 key: "authorization".into(),
709 value: format!(
710 "Basic {}",
711 base64::engine::general_purpose::STANDARD
712 .encode(format!("{username}:{password}"))
713 ),
714 }],
715 },
716 }
717}
718
719fn apply_grpc_auth(
721 auth: &GrpcAuth,
722 request: &mut tonic::Request<Vec<u8>>,
723) -> Result<(), FaucetError> {
724 match auth {
725 GrpcAuth::None => {}
726 GrpcAuth::Bearer { token } => {
727 let val: tonic::metadata::MetadataValue<tonic::metadata::Ascii> =
728 format!("Bearer {token}")
729 .parse()
730 .map_err(|e| FaucetError::Auth(format!("invalid bearer token: {e}")))?;
731 request.metadata_mut().insert("authorization", val);
732 }
733 GrpcAuth::Metadata { entries } => {
734 for entry in entries {
735 let val: tonic::metadata::MetadataValue<tonic::metadata::Ascii> = entry
736 .value
737 .parse()
738 .map_err(|e| FaucetError::Auth(format!("invalid metadata value: {e}")))?;
739 let key: tonic::metadata::MetadataKey<tonic::metadata::Ascii> =
740 entry
741 .key
742 .parse()
743 .map_err(|e| FaucetError::Auth(format!("invalid metadata key: {e}")))?;
744 request.metadata_mut().insert(key, val);
745 }
746 }
747 }
748 Ok(())
749}
750
751fn parse_method_path(
752 service_name: &str,
753 method_name: &str,
754) -> Result<tonic::codegen::http::uri::PathAndQuery, FaucetError> {
755 let full_method = format!("/{service_name}/{method_name}");
756 tonic::codegen::http::uri::PathAndQuery::from_maybe_shared(full_method)
757 .map_err(|e| FaucetError::Url(format!("invalid method path: {e}")))
758}
759
760fn serialize_and_extract(
761 msg: &DynamicMessage,
762 records_path: Option<&str>,
763) -> Result<Vec<Value>, FaucetError> {
764 let serialize_opts = SerializeOptions::new().stringify_64_bit_integers(false);
765 let json_value = msg
766 .serialize_with_options(serde_json::value::Serializer, &serialize_opts)
767 .map_err(|e| {
768 FaucetError::Transform(format!("failed to serialize gRPC response to JSON: {e}"))
769 })?;
770 faucet_core::util::extract_records(&json_value, records_path)
771}
772
/// Compute the delay before the next reconnect: double `current`, but never exceed `cap`.
///
/// Doubling saturates at [`Duration::MAX`] instead of overflowing.
fn next_backoff(current: Duration, cap: Duration) -> Duration {
    let doubled = current.checked_mul(2).unwrap_or(Duration::MAX);
    std::cmp::min(doubled, cap)
}
776
/// Non-success result of one `drive_server_streaming_once` attempt.
///
/// A clean end of stream is reported as `Ok(consumed)` instead of a variant here.
enum StreamOutcome {
    /// `max_messages` was reached. Carries the number of messages emitted by this attempt.
    Done(usize),
    /// The attempt failed after emitting `consumed` messages. The caller decides whether to
    /// reconnect or give up.
    Transient { consumed: usize, error: FaucetError },
}
784
785struct DynamicCodec {
789 output_desc: prost_reflect::MessageDescriptor,
790}
791
792impl DynamicCodec {
793 fn new(output_desc: prost_reflect::MessageDescriptor) -> Self {
794 Self { output_desc }
795 }
796}
797
798impl Codec for DynamicCodec {
799 type Encode = Vec<u8>;
800 type Decode = DynamicMessage;
801 type Encoder = RawEncoder;
802 type Decoder = DynamicDecoder;
803
804 fn encoder(&mut self) -> Self::Encoder {
805 RawEncoder
806 }
807
808 fn decoder(&mut self) -> Self::Decoder {
809 DynamicDecoder {
810 desc: self.output_desc.clone(),
811 }
812 }
813}
814
815struct RawEncoder;
816
817impl Encoder for RawEncoder {
818 type Item = Vec<u8>;
819 type Error = tonic::Status;
820
821 fn encode(&mut self, item: Self::Item, buf: &mut EncodeBuf<'_>) -> Result<(), Self::Error> {
822 use prost::bytes::BufMut;
823 buf.put_slice(&item);
824 Ok(())
825 }
826}
827
828struct DynamicDecoder {
829 desc: prost_reflect::MessageDescriptor,
830}
831
832impl Decoder for DynamicDecoder {
833 type Item = DynamicMessage;
834 type Error = tonic::Status;
835
836 fn decode(&mut self, buf: &mut DecodeBuf<'_>) -> Result<Option<Self::Item>, Self::Error> {
837 use prost::bytes::Buf;
838 if !buf.has_remaining() {
839 return Ok(None);
840 }
841 let bytes = buf.copy_to_bytes(buf.remaining());
842 let msg = DynamicMessage::decode(self.desc.clone(), bytes)
843 .map_err(|e| tonic::Status::internal(format!("protobuf decode error: {e}")))?;
844 Ok(Some(msg))
845 }
846}
847
#[cfg(test)]
mod tests {
    use super::*;
    use faucet_core::{AuthProvider, AuthReference, AuthSpec, Credential, FaucetError};
    use std::sync::Arc;

    /// Write a minimal proto3 `FileDescriptorSet` (one file, no services) to a temp file.
    ///
    /// Keep the returned handle alive until the path has been read; dropping it removes the
    /// file.
    fn dummy_descriptor_file() -> tempfile::NamedTempFile {
        use prost::Message;
        let descriptor_set = prost_types::FileDescriptorSet {
            file: vec![prost_types::FileDescriptorProto {
                name: Some("dummy.proto".into()),
                syntax: Some("proto3".into()),
                ..Default::default()
            }],
        };
        let file = tempfile::NamedTempFile::new().expect("tempfile");
        std::fs::write(file.path(), descriptor_set.encode_to_vec()).expect("write descriptor");
        file
    }

    /// Config targeting a local plaintext endpoint and the given descriptor file.
    fn dummy_config(descriptor: &tempfile::NamedTempFile) -> GrpcStreamConfig {
        GrpcStreamConfig::new(
            "http://localhost:50051",
            "dummy.Svc",
            "Call",
            descriptor.path().to_str().unwrap(),
        )
    }

    /// A `GrpcStream` over the dummy descriptor. It can build requests but has no callable
    /// service.
    fn make_dummy_stream() -> GrpcStream {
        let descriptor = dummy_descriptor_file();
        GrpcStream::new(dummy_config(&descriptor)).expect("new from in-memory descriptor")
    }

    /// Unwrap the entries of a `GrpcAuth::Metadata`, failing the test for any other variant.
    fn expect_metadata(auth: GrpcAuth) -> Vec<MetadataEntry> {
        match auth {
            GrpcAuth::Metadata { entries } => entries,
            other => panic!("expected Metadata, got {other:?}"),
        }
    }

    #[test]
    fn next_backoff_doubles_up_to_cap() {
        let cap = Duration::from_secs(30);
        let mut current = Duration::from_secs(1);
        for expected_secs in [2, 4, 8, 16, 30, 30] {
            current = next_backoff(current, cap);
            assert_eq!(current, Duration::from_secs(expected_secs));
        }
    }

    #[test]
    fn credential_bearer_maps_to_grpc_bearer() {
        let auth = credential_to_auth(Credential::Bearer("tok".into()));
        assert!(matches!(auth, GrpcAuth::Bearer { token } if token == "tok"));
    }

    #[test]
    fn credential_token_maps_to_metadata_authorization() {
        let entries = expect_metadata(credential_to_auth(Credential::Token("Custom xyz".into())));
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].key, "authorization");
        assert_eq!(entries[0].value, "Custom xyz");
    }

    #[test]
    fn credential_header_maps_to_metadata_with_given_name() {
        let entries = expect_metadata(credential_to_auth(Credential::Header {
            name: "x-api-key".into(),
            value: "secret".into(),
        }));
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].key, "x-api-key");
        assert_eq!(entries[0].value, "secret");
    }

    #[test]
    fn credential_basic_maps_to_base64_authorization_metadata() {
        let entries = expect_metadata(credential_to_auth(Credential::Basic {
            username: "alice".into(),
            password: "p@ss".into(),
        }));
        let expected = format!(
            "Basic {}",
            base64::engine::general_purpose::STANDARD.encode("alice:p@ss")
        );
        assert_eq!(entries.len(), 1);
        assert_eq!(entries[0].key, "authorization");
        assert_eq!(entries[0].value, expected);
    }

    #[test]
    fn rejects_zero_reconnect_initial_backoff() {
        let descriptor = dummy_descriptor_file();
        let mut config = dummy_config(&descriptor);
        config.reconnect_initial_backoff = std::time::Duration::ZERO;
        let err = match GrpcStream::new(config) {
            Ok(_) => panic!("a zero reconnect_initial_backoff must be rejected (it busy-spins)"),
            Err(err) => err,
        };
        assert!(matches!(err, FaucetError::Config(_)), "{err:?}");
        assert!(
            err.to_string().contains("reconnect_initial_backoff"),
            "{err}"
        );
    }

    #[tokio::test]
    async fn unresolved_auth_reference_errors_at_request_time() {
        let mut stream = make_dummy_stream();
        stream.config.auth = AuthSpec::Reference(AuthReference {
            name: "missing-provider".into(),
        });

        let err = stream.build_grpc_request(vec![]).await.unwrap_err();

        assert!(
            matches!(err, FaucetError::Auth(_)),
            "expected Auth error, got {err:?}"
        );
        let message = err.to_string();
        assert!(
            message.contains("missing-provider"),
            "error message should name the provider: {message}"
        );
    }

    /// Test provider that always hands out the same bearer token.
    #[derive(Debug)]
    struct FixedBearer(&'static str);

    #[async_trait::async_trait]
    impl AuthProvider for FixedBearer {
        async fn credential(&self) -> Result<Credential, FaucetError> {
            Ok(Credential::Bearer(self.0.to_string()))
        }
        fn provider_name(&self) -> &'static str {
            "fixed-bearer"
        }
    }

    #[tokio::test]
    async fn injected_provider_overrides_inline_none() {
        let provider: SharedAuthProvider = Arc::new(FixedBearer("MYTOKEN"));
        let stream = make_dummy_stream().with_auth_provider(provider);

        let request = stream
            .build_grpc_request(vec![])
            .await
            .expect("build request");
        let authorization = request
            .metadata()
            .get("authorization")
            .expect("authorization metadata must be present")
            .to_str()
            .expect("ascii");

        assert_eq!(authorization, "Bearer MYTOKEN");
    }
}
1040}