1#[allow(unused_imports)]
47use crate::flagd::evaluation::v1::EventStreamRequest;
48use crate::flagd::evaluation::v1::{
49 service_client::ServiceClient, ResolveBooleanRequest, ResolveBooleanResponse,
50 ResolveFloatRequest, ResolveFloatResponse, ResolveIntRequest, ResolveIntResponse,
51 ResolveObjectRequest, ResolveObjectResponse, ResolveStringRequest, ResolveStringResponse,
52};
53use crate::{convert_context, convert_proto_struct_to_struct_value, FlagdOptions};
54use async_trait::async_trait;
55use hyper_util::rt::TokioIo;
56use open_feature::provider::{FeatureProvider, ProviderMetadata, ResolutionDetails};
57use open_feature::{
58 EvaluationContext, EvaluationError, EvaluationErrorCode, EvaluationReason, FlagMetadata,
59 FlagMetadataValue, StructValue,
60};
61use std::collections::HashMap;
62use std::sync::OnceLock;
63use std::time::Duration;
64use tokio::net::UnixStream;
65use tokio::time::sleep;
66use tonic::transport::{Channel, Endpoint, Uri};
67use tower::service_fn;
68use tracing::{debug, error, instrument, warn};
69
70use super::common::upstream::UpstreamConfig;
71
/// Generated gRPC client for flagd's evaluation service, carried over a tonic [`Channel`].
type ClientType = ServiceClient<Channel>;
73
74fn convert_proto_metadata(metadata: prost_types::Struct) -> FlagMetadata {
75 let mut values = HashMap::new();
76 for (k, v) in metadata.fields {
77 let metadata_value = match v.kind.unwrap() {
78 prost_types::value::Kind::BoolValue(b) => FlagMetadataValue::Bool(b),
79 prost_types::value::Kind::NumberValue(n) => FlagMetadataValue::Float(n),
80 prost_types::value::Kind::StringValue(s) => FlagMetadataValue::String(s),
81 _ => FlagMetadataValue::String("unsupported".to_string()),
82 };
83 values.insert(k, metadata_value);
84 }
85 FlagMetadata { values }
86}
87
88pub struct RpcResolver {
89 client: ClientType,
90 metadata: OnceLock<ProviderMetadata>,
91}
92
93impl RpcResolver {
94 #[instrument(skip(options))]
95 pub async fn new(options: &FlagdOptions) -> Result<Self, Box<dyn std::error::Error>> {
96 debug!("initializing RPC resolver connection to {}", options.host);
97
98 let mut retry_delay = Duration::from_millis(options.retry_backoff_ms as u64);
99 let mut attempts = 0;
100
101 loop {
102 match RpcResolver::establish_connection(options).await {
103 Ok(client) => {
104 debug!("Successfully established RPC connection");
105 return Ok(Self {
106 client,
107 metadata: OnceLock::new(),
108 });
109 }
110 Err(e) => {
111 attempts += 1;
112 if attempts >= options.retry_grace_period {
113 error!("Failed to establish connection after {} attempts", attempts);
114 return Err(e);
115 }
116
117 warn!(
118 "Connection attempt {} failed, retrying in {}ms: {}",
119 attempts,
120 retry_delay.as_millis(),
121 e
122 );
123
124 sleep(retry_delay).await;
125 retry_delay = Duration::from_millis((retry_delay.as_millis() * 2) as u64)
126 .min(Duration::from_millis(options.retry_backoff_max_ms as u64));
127 }
128 }
129 }
130 }
131
132 async fn establish_connection(
133 options: &FlagdOptions,
134 ) -> Result<ClientType, Box<dyn std::error::Error>> {
135 if let Some(socket_path) = &options.socket_path {
136 debug!("Attempting Unix socket connection to: {}", socket_path);
137 let socket_path = socket_path.clone();
138 let channel = Endpoint::try_from("http://[::]:50051")?
139 .connect_with_connector(service_fn(move |_: Uri| {
140 let path = socket_path.clone();
141 async move {
142 let stream = UnixStream::connect(path).await?;
143 Ok::<_, std::io::Error>(TokioIo::new(stream))
144 }
145 }))
146 .await?;
147
148 return Ok(ServiceClient::new(channel));
149 }
150
151 let target = options
152 .target_uri
153 .clone()
154 .unwrap_or_else(|| format!("{}:{}", options.host, options.port));
155 let upstream_config = UpstreamConfig::new(target.replace("http://", ""), false)?;
156 let mut endpoint = upstream_config.endpoint().clone();
157
158 if let Some(uri) = &options.target_uri {
160 if uri.starts_with("envoy://") {
161 let without_prefix = uri.trim_start_matches("envoy://");
163 let segments: Vec<&str> = without_prefix.split('/').collect();
164 if segments.len() >= 2 {
165 let authority_str = segments[1];
166 let authority_uri =
168 std::str::FromStr::from_str(&format!("http://{}", authority_str))?;
169 endpoint = endpoint.origin(authority_uri);
170 }
171 }
172 }
173
174 let channel = endpoint
175 .timeout(Duration::from_millis(options.deadline_ms as u64))
176 .connect()
177 .await?;
178
179 Ok(ServiceClient::new(channel))
180 }
181}
182
183#[async_trait]
184impl FeatureProvider for RpcResolver {
185 fn metadata(&self) -> &ProviderMetadata {
186 self.metadata.get_or_init(|| ProviderMetadata::new("flagd"))
187 }
188
189 #[instrument(skip(self, context))]
190 async fn resolve_bool_value(
191 &self,
192 flag_key: &str,
193 context: &EvaluationContext,
194 ) -> Result<ResolutionDetails<bool>, EvaluationError> {
195 debug!(flag_key, "resolving boolean flag");
196 let request = ResolveBooleanRequest {
197 flag_key: flag_key.to_string(),
198 context: convert_context(context),
199 };
200
201 match self.client.clone().resolve_boolean(request).await {
202 Ok(response) => {
203 let inner: ResolveBooleanResponse = response.into_inner();
204 debug!(flag_key, value = inner.value, reason = %inner.reason, "boolean flag resolved");
205 Ok(ResolutionDetails {
206 value: inner.value,
207 variant: Some(inner.variant),
208 reason: Some(EvaluationReason::Other(inner.reason)),
209 flag_metadata: inner.metadata.map(convert_proto_metadata),
210 })
211 }
212 Err(status) => {
213 error!(flag_key, error = %status, "failed to resolve boolean flag");
214 Err(EvaluationError {
215 code: EvaluationErrorCode::General(status.code().to_string()),
216 message: Some(status.message().to_string()),
217 })
218 }
219 }
220 }
221
222 #[instrument(skip(self, context))]
223 async fn resolve_string_value(
224 &self,
225 flag_key: &str,
226 context: &EvaluationContext,
227 ) -> Result<ResolutionDetails<String>, EvaluationError> {
228 debug!(flag_key, "resolving string flag");
229 let request = ResolveStringRequest {
230 flag_key: flag_key.to_string(),
231 context: convert_context(context),
232 };
233
234 match self.client.clone().resolve_string(request).await {
235 Ok(response) => {
236 let inner: ResolveStringResponse = response.into_inner();
237 debug!(flag_key, value = %inner.value, reason = %inner.reason, "string flag resolved");
238 Ok(ResolutionDetails {
239 value: inner.value,
240 variant: Some(inner.variant),
241 reason: Some(EvaluationReason::Other(inner.reason)),
242 flag_metadata: inner.metadata.map(convert_proto_metadata),
243 })
244 }
245 Err(status) => {
246 error!(flag_key, error = %status, "failed to resolve string flag");
247 Err(EvaluationError {
248 code: EvaluationErrorCode::General(status.code().to_string()),
249 message: Some(status.message().to_string()),
250 })
251 }
252 }
253 }
254
255 #[instrument(skip(self, context))]
256 async fn resolve_float_value(
257 &self,
258 flag_key: &str,
259 context: &EvaluationContext,
260 ) -> Result<ResolutionDetails<f64>, EvaluationError> {
261 debug!(flag_key, "resolving float flag");
262 let request = ResolveFloatRequest {
263 flag_key: flag_key.to_string(),
264 context: convert_context(context),
265 };
266
267 match self.client.clone().resolve_float(request).await {
268 Ok(response) => {
269 let inner: ResolveFloatResponse = response.into_inner();
270 debug!(flag_key, value = inner.value, reason = %inner.reason, "float flag resolved");
271 Ok(ResolutionDetails {
272 value: inner.value,
273 variant: Some(inner.variant),
274 reason: Some(EvaluationReason::Other(inner.reason)),
275 flag_metadata: inner.metadata.map(convert_proto_metadata),
276 })
277 }
278 Err(status) => {
279 error!(flag_key, error = %status, "failed to resolve float flag");
280 Err(EvaluationError {
281 code: EvaluationErrorCode::General(status.code().to_string()),
282 message: Some(status.message().to_string()),
283 })
284 }
285 }
286 }
287
288 #[instrument(skip(self, context))]
289 async fn resolve_int_value(
290 &self,
291 flag_key: &str,
292 context: &EvaluationContext,
293 ) -> Result<ResolutionDetails<i64>, EvaluationError> {
294 debug!(flag_key, "resolving integer flag");
295 let request = ResolveIntRequest {
296 flag_key: flag_key.to_string(),
297 context: convert_context(context),
298 };
299
300 match self.client.clone().resolve_int(request).await {
301 Ok(response) => {
302 let inner: ResolveIntResponse = response.into_inner();
303 debug!(flag_key, value = inner.value, reason = %inner.reason, "integer flag resolved");
304 Ok(ResolutionDetails {
305 value: inner.value,
306 variant: Some(inner.variant),
307 reason: Some(EvaluationReason::Other(inner.reason)),
308 flag_metadata: inner.metadata.map(convert_proto_metadata),
309 })
310 }
311 Err(status) => {
312 error!(flag_key, error = %status, "failed to resolve integer flag");
313 Err(EvaluationError {
314 code: EvaluationErrorCode::General(status.code().to_string()),
315 message: Some(status.message().to_string()),
316 })
317 }
318 }
319 }
320
321 #[instrument(skip(self, context))]
322 async fn resolve_struct_value(
323 &self,
324 flag_key: &str,
325 context: &EvaluationContext,
326 ) -> Result<ResolutionDetails<StructValue>, EvaluationError> {
327 debug!(flag_key, "resolving struct flag");
328 let request = ResolveObjectRequest {
329 flag_key: flag_key.to_string(),
330 context: convert_context(context),
331 };
332
333 match self.client.clone().resolve_object(request).await {
334 Ok(response) => {
335 let inner: ResolveObjectResponse = response.into_inner();
336 debug!(flag_key, reason = %inner.reason, "struct flag resolved");
337 Ok(ResolutionDetails {
338 value: convert_proto_struct_to_struct_value(inner.value.unwrap_or_default()),
339 variant: Some(inner.variant),
340 reason: Some(EvaluationReason::Other(inner.reason)),
341 flag_metadata: inner.metadata.map(convert_proto_metadata),
342 })
343 }
344 Err(status) => {
345 error!(flag_key, error = %status, "failed to resolve struct flag");
346 Err(EvaluationError {
347 code: EvaluationErrorCode::General(status.code().to_string()),
348 message: Some(status.message().to_string()),
349 })
350 }
351 }
352 }
353}
354
#[cfg(test)]
mod tests {
    use super::*;
    use crate::flagd::evaluation::v1::{
        service_server::{Service, ServiceServer},
        EventStreamResponse, ResolveAllRequest, ResolveAllResponse,
    };
    use futures_core::Stream;
    use serial_test::serial;
    use std::{collections::BTreeMap, pin::Pin};
    use tempfile::TempDir;
    use test_log::test;
    use tokio::net::UnixListener;
    use tokio::sync::oneshot;
    use tokio::{net::TcpListener, time::Instant};
    use tokio_stream::wrappers::{TcpListenerStream, UnixListenerStream};
    use tonic::{transport::Server, Request, Response, Status};

    /// In-process flagd stand-in that answers every RPC with fixed values.
    pub struct MockFlagService;

    #[tonic::async_trait]
    impl Service for MockFlagService {
        async fn resolve_boolean(
            &self,
            _request: Request<ResolveBooleanRequest>,
        ) -> Result<Response<ResolveBooleanResponse>, Status> {
            Ok(Response::new(ResolveBooleanResponse {
                value: true,
                reason: "test".to_string(),
                variant: "test".to_string(),
                metadata: Some(create_test_metadata()),
            }))
        }

        async fn resolve_string(
            &self,
            _request: Request<ResolveStringRequest>,
        ) -> Result<Response<ResolveStringResponse>, Status> {
            Ok(Response::new(ResolveStringResponse {
                value: "test".to_string(),
                reason: "test".to_string(),
                variant: "test".to_string(),
                metadata: Some(create_test_metadata()),
            }))
        }

        async fn resolve_float(
            &self,
            _request: Request<ResolveFloatRequest>,
        ) -> Result<Response<ResolveFloatResponse>, Status> {
            Ok(Response::new(ResolveFloatResponse {
                value: 1.0,
                reason: "test".to_string(),
                variant: "test".to_string(),
                metadata: Some(create_test_metadata()),
            }))
        }

        async fn resolve_int(
            &self,
            _request: Request<ResolveIntRequest>,
        ) -> Result<Response<ResolveIntResponse>, Status> {
            Ok(Response::new(ResolveIntResponse {
                value: 42,
                reason: "test".to_string(),
                variant: "test".to_string(),
                metadata: Some(create_test_metadata()),
            }))
        }

        async fn resolve_object(
            &self,
            _request: Request<ResolveObjectRequest>,
        ) -> Result<Response<ResolveObjectResponse>, Status> {
            let mut fields = BTreeMap::new();
            fields.insert(
                "key".to_string(),
                prost_types::Value {
                    kind: Some(prost_types::value::Kind::StringValue("value".to_string())),
                },
            );

            Ok(Response::new(ResolveObjectResponse {
                value: Some(prost_types::Struct { fields }),
                reason: "test".to_string(),
                variant: "test".to_string(),
                metadata: Some(create_test_metadata()),
            }))
        }

        async fn resolve_all(
            &self,
            _request: Request<ResolveAllRequest>,
        ) -> Result<Response<ResolveAllResponse>, Status> {
            Ok(Response::new(ResolveAllResponse {
                flags: Default::default(),
                metadata: Some(create_test_metadata()),
            }))
        }

        type EventStreamStream =
            Pin<Box<dyn Stream<Item = Result<EventStreamResponse, Status>> + Send + 'static>>;

        async fn event_stream(
            &self,
            _request: Request<EventStreamRequest>,
        ) -> Result<Response<Self::EventStreamStream>, Status> {
            let output = futures::stream::empty();
            Ok(Response::new(Box::pin(output)))
        }
    }

    /// Builds flag metadata with one bool, one number and one string entry.
    fn create_test_metadata() -> prost_types::Struct {
        let mut fields = BTreeMap::new();
        fields.insert(
            "bool_key".to_string(),
            prost_types::Value {
                kind: Some(prost_types::value::Kind::BoolValue(true)),
            },
        );
        fields.insert(
            "number_key".to_string(),
            prost_types::Value {
                kind: Some(prost_types::value::Kind::NumberValue(42.0)),
            },
        );
        fields.insert(
            "string_key".to_string(),
            prost_types::Value {
                kind: Some(prost_types::value::Kind::StringValue("test".to_string())),
            },
        );
        prost_types::Struct { fields }
    }

    /// Mock flagd gRPC server on an ephemeral localhost port.
    ///
    /// Dropping it drops the shutdown sender, which ends the serving task.
    struct TestServer {
        target: String,
        _shutdown: oneshot::Sender<()>,
    }

    impl TestServer {
        async fn new() -> Self {
            let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
            let addr = listener.local_addr().unwrap();
            let (tx, rx) = oneshot::channel();

            // Serve on the listener we already hold rather than re-binding `addr`.
            // Re-binding races with this listener (EADDRINUSE) and, once the listener
            // is dropped, with any other process taking the port.
            let server = Server::builder()
                .add_service(ServiceServer::new(MockFlagService))
                .serve_with_incoming(TcpListenerStream::new(listener));

            tokio::spawn(async move {
                tokio::select! {
                    _ = server => {},
                    _ = rx => {},
                }
            });

            Self {
                target: format!("{}:{}", addr.ip(), addr.port()),
                _shutdown: tx,
            }
        }
    }

    #[test(tokio::test(flavor = "multi_thread", worker_threads = 1))]
    async fn test_dns_resolution() {
        let server = TestServer::new().await;
        let options = FlagdOptions {
            host: server.target.clone(),
            port: 8013,
            target_uri: None,
            deadline_ms: 500,
            ..Default::default()
        };
        let resolver = RpcResolver::new(&options).await.unwrap();
        let context = EvaluationContext::default().with_targeting_key("test-user");

        let result = resolver
            .resolve_bool_value("test-flag", &context)
            .await
            .unwrap();
        assert!(result.value);
    }

    #[test(tokio::test(flavor = "multi_thread", worker_threads = 1))]
    async fn test_envoy_resolution() {
        let server = TestServer::new().await;
        let options = FlagdOptions {
            host: server.target.clone(),
            port: 8013,
            target_uri: Some(format!("envoy://{}/flagd-service", server.target)),
            deadline_ms: 500,
            ..Default::default()
        };

        let resolver = RpcResolver::new(&options).await.unwrap();
        let context = EvaluationContext::default().with_targeting_key("test-user");

        let result = resolver
            .resolve_bool_value("test-flag", &context)
            .await
            .unwrap();
        assert!(result.value);
    }

    #[test(tokio::test(flavor = "multi_thread", worker_threads = 1))]
    async fn test_value_resolution() {
        let server = TestServer::new().await;
        let options = FlagdOptions {
            host: server.target.clone(),
            port: 8013,
            target_uri: None,
            deadline_ms: 500,
            ..Default::default()
        };
        let resolver = RpcResolver::new(&options).await.unwrap();
        let context = EvaluationContext::default().with_targeting_key("test-user");

        assert!(
            resolver
                .resolve_bool_value("test-flag", &context)
                .await
                .unwrap()
                .value
        );
        assert_eq!(
            resolver
                .resolve_string_value("test-flag", &context)
                .await
                .unwrap()
                .value,
            "test"
        );
        assert_eq!(
            resolver
                .resolve_float_value("test-flag", &context)
                .await
                .unwrap()
                .value,
            1.0
        );
        assert_eq!(
            resolver
                .resolve_int_value("test-flag", &context)
                .await
                .unwrap()
                .value,
            42
        );

        let struct_result = resolver
            .resolve_struct_value("test-flag", &context)
            .await
            .unwrap();
        assert!(!struct_result.value.fields.is_empty());
    }

    // Pure conversion: no async runtime needed.
    #[test]
    fn test_metadata() {
        let metadata = create_test_metadata();
        let flag_metadata = convert_proto_metadata(metadata);

        assert!(matches!(
            flag_metadata.values.get("bool_key"),
            Some(FlagMetadataValue::Bool(true))
        ));
        assert!(matches!(
            flag_metadata.values.get("number_key"),
            Some(FlagMetadataValue::Float(42.0))
        ));
        assert!(matches!(
            flag_metadata.values.get("string_key"),
            Some(FlagMetadataValue::String(s)) if s == "test"
        ));
    }

    #[test(tokio::test(flavor = "multi_thread", worker_threads = 1))]
    async fn test_standard_connection() {
        let server = TestServer::new().await;
        let parts: Vec<&str> = server.target.split(':').collect();
        let options = FlagdOptions {
            host: parts[0].to_string(),
            port: parts[1].parse().unwrap(),
            target_uri: None,
            deadline_ms: 500,
            ..Default::default()
        };

        let resolver = RpcResolver::new(&options).await.unwrap();
        let context = EvaluationContext::default().with_targeting_key("test-user");

        let result = resolver
            .resolve_bool_value("test-flag", &context)
            .await
            .unwrap();
        assert!(result.value);
    }

    #[test(tokio::test(flavor = "multi_thread", worker_threads = 1))]
    async fn test_envoy_connection() {
        let server = TestServer::new().await;
        let parts: Vec<&str> = server.target.split(':').collect();
        let options = FlagdOptions {
            host: parts[0].to_string(),
            port: parts[1].parse().unwrap(),
            target_uri: Some(format!("envoy://{}/flagd-service", server.target)),
            deadline_ms: 500,
            ..Default::default()
        };

        let resolver = RpcResolver::new(&options).await.unwrap();
        let context = EvaluationContext::default().with_targeting_key("test-user");

        let result = resolver
            .resolve_bool_value("test-flag", &context)
            .await
            .unwrap();
        assert!(result.value);
    }

    #[test(tokio::test(flavor = "multi_thread", worker_threads = 1))]
    #[serial]
    async fn test_retry_mechanism() {
        let options = FlagdOptions {
            host: "invalid-host".to_string(),
            port: 8013,
            retry_backoff_ms: 100,
            retry_backoff_max_ms: 400,
            retry_grace_period: 3,
            ..Default::default()
        };

        let start = Instant::now();
        let result = RpcResolver::new(&options).await;
        let duration = start.elapsed();

        // Three attempts with 100ms + 200ms of backoff between them.
        assert!(result.is_err());
        assert!(duration.as_millis() >= 300);
        assert!(duration.as_millis() < 600);
    }

    #[test(tokio::test)]
    async fn test_successful_retry() {
        let server = TestServer::new().await;
        let options = FlagdOptions {
            host: server.target.clone(),
            port: 8013,
            retry_backoff_ms: 100,
            retry_backoff_max_ms: 400,
            retry_grace_period: 3,
            ..Default::default()
        };

        let resolver = RpcResolver::new(&options).await.unwrap();
        let context = EvaluationContext::default();

        let result = resolver
            .resolve_bool_value("test-flag", &context)
            .await
            .unwrap();
        assert!(result.value);
    }

    #[test(tokio::test)]
    async fn test_rpc_unix_socket_connection() {
        let tmp_dir = TempDir::new().unwrap();
        let socket_path = tmp_dir.path().join("test.sock");
        let socket_path_str = socket_path.to_str().unwrap().to_string();

        // Bind before spawning so the socket exists before the client dials it,
        // instead of relying on a fixed sleep.
        let uds = UnixListener::bind(&socket_path).unwrap();
        let server_handle = tokio::spawn(async move {
            Server::builder()
                .add_service(ServiceServer::new(MockFlagService))
                .serve_with_incoming(UnixListenerStream::new(uds))
                .await
                .unwrap();
        });

        let options = FlagdOptions {
            socket_path: Some(socket_path_str),
            retry_backoff_ms: 100,
            retry_backoff_max_ms: 400,
            retry_grace_period: 3,
            ..Default::default()
        };

        let resolver = RpcResolver::new(&options).await;
        assert!(resolver.is_ok());

        server_handle.abort();
    }
}