1#[allow(unused_imports)]
47use crate::flagd::evaluation::v1::EventStreamRequest;
48use crate::flagd::evaluation::v1::{
49 service_client::ServiceClient, ResolveBooleanRequest, ResolveBooleanResponse,
50 ResolveFloatRequest, ResolveFloatResponse, ResolveIntRequest, ResolveIntResponse,
51 ResolveObjectRequest, ResolveObjectResponse, ResolveStringRequest, ResolveStringResponse,
52};
53use crate::{convert_context, convert_proto_struct_to_struct_value, FlagdOptions};
54use async_trait::async_trait;
55use hyper_util::rt::TokioIo;
56use open_feature::provider::{FeatureProvider, ProviderMetadata, ResolutionDetails};
57use open_feature::{
58 EvaluationContext, EvaluationError, EvaluationErrorCode, EvaluationReason, FlagMetadata,
59 FlagMetadataValue, StructValue,
60};
61use std::collections::HashMap;
62use std::sync::OnceLock;
63use std::time::Duration;
64use tokio::net::UnixStream;
65use tokio::time::sleep;
66use tonic::transport::{Channel, Endpoint, Uri};
67use tower::service_fn;
68use tracing::{debug, error, instrument, warn};
69
70use super::common::upstream::UpstreamConfig;
71
72type ClientType = ServiceClient<Channel>;
73
74fn convert_proto_metadata(metadata: prost_types::Struct) -> FlagMetadata {
75 let mut values = HashMap::new();
76 for (k, v) in metadata.fields {
77 let metadata_value = match v.kind.unwrap() {
78 prost_types::value::Kind::BoolValue(b) => FlagMetadataValue::Bool(b),
79 prost_types::value::Kind::NumberValue(n) => FlagMetadataValue::Float(n),
80 prost_types::value::Kind::StringValue(s) => FlagMetadataValue::String(s),
81 _ => FlagMetadataValue::String("unsupported".to_string()),
82 };
83 values.insert(k, metadata_value);
84 }
85 FlagMetadata { values }
86}
87
88pub struct RpcResolver {
89 client: ClientType,
90 metadata: OnceLock<ProviderMetadata>,
91}
92
93impl RpcResolver {
94 #[instrument(skip(options))]
95 pub async fn new(
96 options: &FlagdOptions,
97 ) -> Result<Self, Box<dyn std::error::Error + Send + Sync>> {
98 debug!("initializing RPC resolver connection to {}", options.host);
99
100 let mut retry_delay = Duration::from_millis(options.retry_backoff_ms as u64);
101 let mut attempts = 0;
102
103 loop {
104 match RpcResolver::establish_connection(options).await {
105 Ok(client) => {
106 debug!("Successfully established RPC connection");
107 return Ok(Self {
108 client,
109 metadata: OnceLock::new(),
110 });
111 }
112 Err(e) => {
113 attempts += 1;
114 if attempts >= options.retry_grace_period {
115 error!("Failed to establish connection after {} attempts", attempts);
116 return Err(e);
117 }
118
119 warn!(
120 "Connection attempt {} failed, retrying in {}ms: {}",
121 attempts,
122 retry_delay.as_millis(),
123 e
124 );
125
126 sleep(retry_delay).await;
127 retry_delay = Duration::from_millis((retry_delay.as_millis() * 2) as u64)
128 .min(Duration::from_millis(options.retry_backoff_max_ms as u64));
129 }
130 }
131 }
132 }
133
134 async fn establish_connection(
135 options: &FlagdOptions,
136 ) -> Result<ClientType, Box<dyn std::error::Error + Send + Sync>> {
137 if let Some(socket_path) = &options.socket_path {
138 debug!("Attempting Unix socket connection to: {}", socket_path);
139 let socket_path = socket_path.clone();
140 let channel = Endpoint::try_from("http://[::]:50051")?
141 .connect_with_connector(service_fn(move |_: Uri| {
142 let path = socket_path.clone();
143 async move {
144 let stream = UnixStream::connect(path).await?;
145 Ok::<_, std::io::Error>(TokioIo::new(stream))
146 }
147 }))
148 .await?;
149
150 return Ok(ServiceClient::new(channel));
151 }
152
153 let target = options
154 .target_uri
155 .clone()
156 .unwrap_or_else(|| format!("{}:{}", options.host, options.port));
157 let upstream_config = UpstreamConfig::new(target.replace("http://", ""), false)?;
158 let mut endpoint = upstream_config.endpoint().clone();
159
160 if let Some(uri) = &options.target_uri {
162 if uri.starts_with("envoy://") {
163 let without_prefix = uri.trim_start_matches("envoy://");
165 let segments: Vec<&str> = without_prefix.split('/').collect();
166 if segments.len() >= 2 {
167 let authority_str = segments[1];
168 let authority_uri =
170 std::str::FromStr::from_str(&format!("http://{}", authority_str))?;
171 endpoint = endpoint.origin(authority_uri);
172 }
173 }
174 }
175
176 let channel = endpoint
177 .timeout(Duration::from_millis(options.deadline_ms as u64))
178 .connect()
179 .await?;
180
181 Ok(ServiceClient::new(channel))
182 }
183}
184
185#[async_trait]
186impl FeatureProvider for RpcResolver {
187 fn metadata(&self) -> &ProviderMetadata {
188 self.metadata.get_or_init(|| ProviderMetadata::new("flagd"))
189 }
190
191 #[instrument(skip(self, context))]
192 async fn resolve_bool_value(
193 &self,
194 flag_key: &str,
195 context: &EvaluationContext,
196 ) -> Result<ResolutionDetails<bool>, EvaluationError> {
197 debug!(flag_key, "resolving boolean flag");
198 let request = ResolveBooleanRequest {
199 flag_key: flag_key.to_string(),
200 context: convert_context(context),
201 };
202
203 match self.client.clone().resolve_boolean(request).await {
204 Ok(response) => {
205 let inner: ResolveBooleanResponse = response.into_inner();
206 debug!(flag_key, value = inner.value, reason = %inner.reason, "boolean flag resolved");
207 Ok(ResolutionDetails {
208 value: inner.value,
209 variant: Some(inner.variant),
210 reason: Some(EvaluationReason::Other(inner.reason)),
211 flag_metadata: inner.metadata.map(convert_proto_metadata),
212 })
213 }
214 Err(status) => {
215 error!(flag_key, error = %status, "failed to resolve boolean flag");
216 Err(EvaluationError {
217 code: EvaluationErrorCode::General(status.code().to_string()),
218 message: Some(status.message().to_string()),
219 })
220 }
221 }
222 }
223
224 #[instrument(skip(self, context))]
225 async fn resolve_string_value(
226 &self,
227 flag_key: &str,
228 context: &EvaluationContext,
229 ) -> Result<ResolutionDetails<String>, EvaluationError> {
230 debug!(flag_key, "resolving string flag");
231 let request = ResolveStringRequest {
232 flag_key: flag_key.to_string(),
233 context: convert_context(context),
234 };
235
236 match self.client.clone().resolve_string(request).await {
237 Ok(response) => {
238 let inner: ResolveStringResponse = response.into_inner();
239 debug!(flag_key, value = %inner.value, reason = %inner.reason, "string flag resolved");
240 Ok(ResolutionDetails {
241 value: inner.value,
242 variant: Some(inner.variant),
243 reason: Some(EvaluationReason::Other(inner.reason)),
244 flag_metadata: inner.metadata.map(convert_proto_metadata),
245 })
246 }
247 Err(status) => {
248 error!(flag_key, error = %status, "failed to resolve string flag");
249 Err(EvaluationError {
250 code: EvaluationErrorCode::General(status.code().to_string()),
251 message: Some(status.message().to_string()),
252 })
253 }
254 }
255 }
256
257 #[instrument(skip(self, context))]
258 async fn resolve_float_value(
259 &self,
260 flag_key: &str,
261 context: &EvaluationContext,
262 ) -> Result<ResolutionDetails<f64>, EvaluationError> {
263 debug!(flag_key, "resolving float flag");
264 let request = ResolveFloatRequest {
265 flag_key: flag_key.to_string(),
266 context: convert_context(context),
267 };
268
269 match self.client.clone().resolve_float(request).await {
270 Ok(response) => {
271 let inner: ResolveFloatResponse = response.into_inner();
272 debug!(flag_key, value = inner.value, reason = %inner.reason, "float flag resolved");
273 Ok(ResolutionDetails {
274 value: inner.value,
275 variant: Some(inner.variant),
276 reason: Some(EvaluationReason::Other(inner.reason)),
277 flag_metadata: inner.metadata.map(convert_proto_metadata),
278 })
279 }
280 Err(status) => {
281 error!(flag_key, error = %status, "failed to resolve float flag");
282 Err(EvaluationError {
283 code: EvaluationErrorCode::General(status.code().to_string()),
284 message: Some(status.message().to_string()),
285 })
286 }
287 }
288 }
289
290 #[instrument(skip(self, context))]
291 async fn resolve_int_value(
292 &self,
293 flag_key: &str,
294 context: &EvaluationContext,
295 ) -> Result<ResolutionDetails<i64>, EvaluationError> {
296 debug!(flag_key, "resolving integer flag");
297 let request = ResolveIntRequest {
298 flag_key: flag_key.to_string(),
299 context: convert_context(context),
300 };
301
302 match self.client.clone().resolve_int(request).await {
303 Ok(response) => {
304 let inner: ResolveIntResponse = response.into_inner();
305 debug!(flag_key, value = inner.value, reason = %inner.reason, "integer flag resolved");
306 Ok(ResolutionDetails {
307 value: inner.value,
308 variant: Some(inner.variant),
309 reason: Some(EvaluationReason::Other(inner.reason)),
310 flag_metadata: inner.metadata.map(convert_proto_metadata),
311 })
312 }
313 Err(status) => {
314 error!(flag_key, error = %status, "failed to resolve integer flag");
315 Err(EvaluationError {
316 code: EvaluationErrorCode::General(status.code().to_string()),
317 message: Some(status.message().to_string()),
318 })
319 }
320 }
321 }
322
323 #[instrument(skip(self, context))]
324 async fn resolve_struct_value(
325 &self,
326 flag_key: &str,
327 context: &EvaluationContext,
328 ) -> Result<ResolutionDetails<StructValue>, EvaluationError> {
329 debug!(flag_key, "resolving struct flag");
330 let request = ResolveObjectRequest {
331 flag_key: flag_key.to_string(),
332 context: convert_context(context),
333 };
334
335 match self.client.clone().resolve_object(request).await {
336 Ok(response) => {
337 let inner: ResolveObjectResponse = response.into_inner();
338 debug!(flag_key, reason = %inner.reason, "struct flag resolved");
339 Ok(ResolutionDetails {
340 value: convert_proto_struct_to_struct_value(inner.value.unwrap_or_default()),
341 variant: Some(inner.variant),
342 reason: Some(EvaluationReason::Other(inner.reason)),
343 flag_metadata: inner.metadata.map(convert_proto_metadata),
344 })
345 }
346 Err(status) => {
347 error!(flag_key, error = %status, "failed to resolve struct flag");
348 Err(EvaluationError {
349 code: EvaluationErrorCode::General(status.code().to_string()),
350 message: Some(status.message().to_string()),
351 })
352 }
353 }
354 }
355}
356
#[cfg(test)]
mod tests {
    use super::*;
    use crate::flagd::evaluation::v1::{
        service_server::{Service, ServiceServer},
        EventStreamResponse, ResolveAllRequest, ResolveAllResponse,
    };
    use futures_core::Stream;
    use serial_test::serial;
    use std::{collections::BTreeMap, pin::Pin};
    use tempfile::TempDir;
    use test_log::test;
    use tokio::net::UnixListener;
    use tokio::sync::oneshot;
    use tokio::{net::TcpListener, time::Instant};
    use tokio_stream::wrappers::{TcpListenerStream, UnixListenerStream};
    use tonic::{transport::Server, Request, Response, Status};

    /// In-process flagd evaluation service returning fixed values for every flag.
    pub struct MockFlagService;

    #[tonic::async_trait]
    impl Service for MockFlagService {
        async fn resolve_boolean(
            &self,
            _request: Request<ResolveBooleanRequest>,
        ) -> Result<Response<ResolveBooleanResponse>, Status> {
            Ok(Response::new(ResolveBooleanResponse {
                value: true,
                reason: "test".to_string(),
                variant: "test".to_string(),
                metadata: Some(create_test_metadata()),
            }))
        }

        async fn resolve_string(
            &self,
            _request: Request<ResolveStringRequest>,
        ) -> Result<Response<ResolveStringResponse>, Status> {
            Ok(Response::new(ResolveStringResponse {
                value: "test".to_string(),
                reason: "test".to_string(),
                variant: "test".to_string(),
                metadata: Some(create_test_metadata()),
            }))
        }

        async fn resolve_float(
            &self,
            _request: Request<ResolveFloatRequest>,
        ) -> Result<Response<ResolveFloatResponse>, Status> {
            Ok(Response::new(ResolveFloatResponse {
                value: 1.0,
                reason: "test".to_string(),
                variant: "test".to_string(),
                metadata: Some(create_test_metadata()),
            }))
        }

        async fn resolve_int(
            &self,
            _request: Request<ResolveIntRequest>,
        ) -> Result<Response<ResolveIntResponse>, Status> {
            Ok(Response::new(ResolveIntResponse {
                value: 42,
                reason: "test".to_string(),
                variant: "test".to_string(),
                metadata: Some(create_test_metadata()),
            }))
        }

        async fn resolve_object(
            &self,
            _request: Request<ResolveObjectRequest>,
        ) -> Result<Response<ResolveObjectResponse>, Status> {
            let mut fields = BTreeMap::new();
            fields.insert(
                "key".to_string(),
                prost_types::Value {
                    kind: Some(prost_types::value::Kind::StringValue("value".to_string())),
                },
            );

            Ok(Response::new(ResolveObjectResponse {
                value: Some(prost_types::Struct { fields }),
                reason: "test".to_string(),
                variant: "test".to_string(),
                metadata: Some(create_test_metadata()),
            }))
        }

        async fn resolve_all(
            &self,
            _request: Request<ResolveAllRequest>,
        ) -> Result<Response<ResolveAllResponse>, Status> {
            Ok(Response::new(ResolveAllResponse {
                flags: Default::default(),
                metadata: Some(create_test_metadata()),
            }))
        }

        type EventStreamStream =
            Pin<Box<dyn Stream<Item = Result<EventStreamResponse, Status>> + Send + 'static>>;

        async fn event_stream(
            &self,
            _request: Request<EventStreamRequest>,
        ) -> Result<Response<Self::EventStreamStream>, Status> {
            let output = futures::stream::empty();
            Ok(Response::new(Box::pin(output)))
        }
    }

    /// Metadata with one bool, one number and one string entry.
    fn create_test_metadata() -> prost_types::Struct {
        let mut fields = BTreeMap::new();
        fields.insert(
            "bool_key".to_string(),
            prost_types::Value {
                kind: Some(prost_types::value::Kind::BoolValue(true)),
            },
        );
        fields.insert(
            "number_key".to_string(),
            prost_types::Value {
                kind: Some(prost_types::value::Kind::NumberValue(42.0)),
            },
        );
        fields.insert(
            "string_key".to_string(),
            prost_types::Value {
                kind: Some(prost_types::value::Kind::StringValue("test".to_string())),
            },
        );
        prost_types::Struct { fields }
    }

    /// Mock flagd server on an ephemeral localhost port; shuts down when dropped.
    struct TestServer {
        target: String,
        _shutdown: oneshot::Sender<()>,
    }

    impl TestServer {
        /// Starts the mock service on `127.0.0.1:<ephemeral>`.
        ///
        /// The bound listener is handed straight to tonic, so the port is already accepting
        /// connections when this returns. Re-binding the address via `serve(addr)` would
        /// race with the still-open probe listener.
        async fn new() -> Self {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let addr = listener.local_addr().unwrap();
            let (tx, rx) = oneshot::channel();

            let server = Server::builder()
                .add_service(ServiceServer::new(MockFlagService))
                .serve_with_incoming(TcpListenerStream::new(listener));

            tokio::spawn(async move {
                tokio::select! {
                    _ = server => {},
                    _ = rx => {},
                }
            });

            Self {
                target: format!("{}:{}", addr.ip(), addr.port()),
                _shutdown: tx,
            }
        }
    }

    #[test(tokio::test(flavor = "multi_thread", worker_threads = 1))]
    async fn test_dns_resolution() {
        let server = TestServer::new().await;
        let options = FlagdOptions {
            host: server.target.clone(),
            port: 8013,
            target_uri: None,
            deadline_ms: 500,
            ..Default::default()
        };
        let resolver = RpcResolver::new(&options).await.unwrap();
        let context = EvaluationContext::default().with_targeting_key("test-user");

        let result = resolver
            .resolve_bool_value("test-flag", &context)
            .await
            .unwrap();
        assert!(result.value);
    }

    #[test(tokio::test(flavor = "multi_thread", worker_threads = 1))]
    async fn test_envoy_resolution() {
        let server = TestServer::new().await;
        let options = FlagdOptions {
            host: server.target.clone(),
            port: 8013,
            target_uri: Some(format!("envoy://{}/flagd-service", server.target)),
            deadline_ms: 500,
            ..Default::default()
        };

        let resolver = RpcResolver::new(&options).await.unwrap();
        let context = EvaluationContext::default().with_targeting_key("test-user");

        let result = resolver
            .resolve_bool_value("test-flag", &context)
            .await
            .unwrap();
        assert!(result.value);
    }

    #[test(tokio::test(flavor = "multi_thread", worker_threads = 1))]
    async fn test_value_resolution() {
        let server = TestServer::new().await;
        let options = FlagdOptions {
            host: server.target.clone(),
            port: 8013,
            target_uri: None,
            deadline_ms: 500,
            ..Default::default()
        };
        let resolver = RpcResolver::new(&options).await.unwrap();
        let context = EvaluationContext::default().with_targeting_key("test-user");

        assert!(
            resolver
                .resolve_bool_value("test-flag", &context)
                .await
                .unwrap()
                .value
        );
        assert_eq!(
            resolver
                .resolve_string_value("test-flag", &context)
                .await
                .unwrap()
                .value,
            "test"
        );
        assert_eq!(
            resolver
                .resolve_float_value("test-flag", &context)
                .await
                .unwrap()
                .value,
            1.0
        );
        assert_eq!(
            resolver
                .resolve_int_value("test-flag", &context)
                .await
                .unwrap()
                .value,
            42
        );

        let struct_result = resolver
            .resolve_struct_value("test-flag", &context)
            .await
            .unwrap();
        assert!(!struct_result.value.fields.is_empty());
    }

    #[test(tokio::test(flavor = "multi_thread", worker_threads = 1))]
    async fn test_metadata() {
        let metadata = create_test_metadata();
        let flag_metadata = convert_proto_metadata(metadata);

        assert!(matches!(
            flag_metadata.values.get("bool_key"),
            Some(FlagMetadataValue::Bool(true))
        ));
        assert!(matches!(
            flag_metadata.values.get("number_key"),
            Some(FlagMetadataValue::Float(42.0))
        ));
        assert!(matches!(
            flag_metadata.values.get("string_key"),
            Some(FlagMetadataValue::String(s)) if s == "test"
        ));
    }

    #[test(tokio::test(flavor = "multi_thread", worker_threads = 1))]
    async fn test_standard_connection() {
        let server = TestServer::new().await;
        let parts: Vec<&str> = server.target.split(':').collect();
        let options = FlagdOptions {
            host: parts[0].to_string(),
            port: parts[1].parse().unwrap(),
            target_uri: None,
            deadline_ms: 500,
            ..Default::default()
        };

        let resolver = RpcResolver::new(&options).await.unwrap();
        let context = EvaluationContext::default().with_targeting_key("test-user");

        let result = resolver
            .resolve_bool_value("test-flag", &context)
            .await
            .unwrap();
        assert!(result.value);
    }

    #[test(tokio::test(flavor = "multi_thread", worker_threads = 1))]
    async fn test_envoy_connection() {
        let server = TestServer::new().await;
        let parts: Vec<&str> = server.target.split(':').collect();
        let options = FlagdOptions {
            host: parts[0].to_string(),
            port: parts[1].parse().unwrap(),
            target_uri: Some(format!("envoy://{}/flagd-service", server.target)),
            deadline_ms: 500,
            ..Default::default()
        };

        let resolver = RpcResolver::new(&options).await.unwrap();
        let context = EvaluationContext::default().with_targeting_key("test-user");

        let result = resolver
            .resolve_bool_value("test-flag", &context)
            .await
            .unwrap();
        assert!(result.value);
    }

    #[test(tokio::test(flavor = "multi_thread", worker_threads = 1))]
    #[serial]
    async fn test_retry_mechanism() {
        let options = FlagdOptions {
            host: "invalid-host".to_string(),
            port: 8013,
            retry_backoff_ms: 100,
            retry_backoff_max_ms: 400,
            retry_grace_period: 3,
            ..Default::default()
        };

        let start = Instant::now();
        let result = RpcResolver::new(&options).await;
        let duration = start.elapsed();

        assert!(result.is_err());
        // Three attempts sleep 100ms + 200ms between them.
        assert!(duration.as_millis() >= 300);
        assert!(duration.as_millis() < 600);
    }

    #[test(tokio::test)]
    async fn test_successful_retry() {
        let server = TestServer::new().await;
        let options = FlagdOptions {
            host: server.target.clone(),
            port: 8013,
            retry_backoff_ms: 100,
            retry_backoff_max_ms: 400,
            retry_grace_period: 3,
            ..Default::default()
        };

        let resolver = RpcResolver::new(&options).await.unwrap();
        let context = EvaluationContext::default();

        let result = resolver
            .resolve_bool_value("test-flag", &context)
            .await
            .unwrap();
        assert!(result.value);
    }

    #[test(tokio::test)]
    async fn test_rpc_unix_socket_connection() {
        let tmp_dir = TempDir::new().unwrap();
        let socket_path = tmp_dir.path().join("test.sock");
        let socket_path_str = socket_path.to_str().unwrap().to_string();

        // Bind before spawning so the socket file exists before the client dials it.
        let uds = UnixListener::bind(&socket_path).unwrap();
        let server_handle = tokio::spawn(async move {
            Server::builder()
                .add_service(ServiceServer::new(MockFlagService))
                .serve_with_incoming(UnixListenerStream::new(uds))
                .await
                .unwrap();
        });

        let options = FlagdOptions {
            socket_path: Some(socket_path_str),
            retry_backoff_ms: 100,
            retry_backoff_max_ms: 400,
            retry_grace_period: 3,
            ..Default::default()
        };

        let resolver = RpcResolver::new(&options).await;
        assert!(resolver.is_ok());

        server_handle.abort();
    }
}