Skip to main content

polars_io/cloud/
credential_provider.rs

1use std::fmt::Debug;
2use std::future::Future;
3use std::hash::Hash;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::time::{SystemTime, UNIX_EPOCH};
7
8use async_trait::async_trait;
9#[cfg(feature = "aws")]
10pub use object_store::aws::AwsCredential;
11#[cfg(feature = "azure")]
12pub use object_store::azure::AzureCredential;
13#[cfg(feature = "gcp")]
14pub use object_store::gcp::GcpCredential;
15use polars_core::config;
16use polars_error::{PolarsResult, polars_bail};
17use polars_utils::pl_str::PlSmallStr;
18#[cfg(feature = "python")]
19use polars_utils::python_function::PythonObject;
20#[cfg(feature = "python")]
21use python_impl::PythonCredentialProvider;
22
/// A source of cloud storage credentials.
///
/// It is either a Rust closure (`Function`) or, with the `python` feature enabled, a Python
/// object (`Python`).
#[derive(Clone, Debug, PartialEq, Hash, Eq)]
pub enum PlCredentialProvider {
    /// Prefer using [`PlCredentialProvider::from_func`] instead of constructing this directly
    Function(CredentialProviderFunction),
    /// Provider backed by a Python object; see [`PlCredentialProvider::from_python_builder`].
    #[cfg(feature = "python")]
    Python(PythonCredentialProvider),
}
30
31impl PlCredentialProvider {
32    /// Accepts a function that returns (credential, expiry time as seconds since UNIX_EPOCH)
33    ///
34    /// This functionality is unstable.
35    pub fn from_func(
36        // Internal notes
37        // * This function is exposed as the Rust API for `PlCredentialProvider`
38        func: impl Fn() -> Pin<
39            Box<dyn Future<Output = PolarsResult<(ObjectStoreCredential, u64)>> + Send + Sync>,
40        > + Send
41        + Sync
42        + 'static,
43    ) -> Self {
44        Self::Function(CredentialProviderFunction(Arc::new(func)))
45    }
46
47    /// Intended to be called with an internal `CredentialProviderBuilder` from
48    /// py-polars.
49    #[cfg(feature = "python")]
50    pub fn from_python_builder(func: pyo3::Py<pyo3::PyAny>) -> Self {
51        Self::Python(python_impl::PythonCredentialProvider::Builder(Arc::new(
52            PythonObject(func),
53        )))
54    }
55
56    #[allow(unused)]
57    pub(super) fn func_addr(&self) -> usize {
58        match self {
59            Self::Function(CredentialProviderFunction(v)) => Arc::as_ptr(v) as *const () as usize,
60            #[cfg(feature = "python")]
61            Self::Python(v) => v.func_addr(),
62        }
63    }
64
65    /// Python passes a `CredentialProviderBuilder`, this calls the builder to build the final
66    /// credential provider.
67    ///
68    /// This returns `Option` as the auto-initialization case is fallible and falls back to None.
69    pub(crate) fn try_into_initialized(
70        self,
71        clear_cached_credentials: bool,
72    ) -> PolarsResult<Option<Self>> {
73        match self {
74            Self::Function(_) => Ok(Some(self)),
75            #[cfg(feature = "python")]
76            Self::Python(v) => Ok(v
77                .try_into_initialized(clear_cached_credentials)?
78                .map(Self::Python)),
79        }
80    }
81
82    pub fn stable_cache_key(&self) -> PolarsResult<Vec<u8>> {
83        match self {
84            Self::Function(CredentialProviderFunction(v)) => Ok((Arc::as_ptr(v) as *const ()
85                as usize)
86                .to_ne_bytes()
87                .to_vec()),
88            #[cfg(feature = "python")]
89            Self::Python(v) => v.stable_cache_key(),
90        }
91    }
92}
93
/// A credential for one of the supported object stores, as returned by a credential
/// provider function.
///
/// Each cloud variant only exists when the matching cargo feature is enabled.
pub enum ObjectStoreCredential {
    #[cfg(feature = "aws")]
    Aws(Arc<object_store::aws::AwsCredential>),
    #[cfg(feature = "azure")]
    Azure(Arc<object_store::azure::AzureCredential>),
    #[cfg(feature = "gcp")]
    Gcp(Arc<object_store::gcp::GcpCredential>),
    /// For testing purposes
    None,
}

impl ObjectStoreCredential {
    /// Name of the active variant, used in diagnostics.
    fn variant_name(&self) -> &'static str {
        match self {
            #[cfg(feature = "aws")]
            Self::Aws(_) => "Aws",
            #[cfg(feature = "azure")]
            Self::Azure(_) => "Azure",
            #[cfg(feature = "gcp")]
            Self::Gcp(_) => "Gcp",
            Self::None => "None",
        }
    }

    /// Reports that a credential of kind `expected` was required but `self` is another variant.
    ///
    /// Declared as diverging (`-> !`) so callers can use it directly in a branch that must
    /// produce a value, without a trailing `unreachable!()`.
    ///
    /// # Panics
    /// Always.
    #[cold]
    fn panic_type_mismatch(&self, expected: &str) -> ! {
        panic!(
            "impl error: credential type mismatch: expected {}, got {} instead",
            expected,
            self.variant_name()
        )
    }

    /// Extracts the AWS credential.
    ///
    /// # Panics
    /// Panics if `self` is not the `Aws` variant.
    #[cfg(feature = "aws")]
    fn unwrap_aws(self) -> Arc<object_store::aws::AwsCredential> {
        match self {
            Self::Aws(v) => v,
            other => other.panic_type_mismatch("aws"),
        }
    }

    /// Extracts the Azure credential.
    ///
    /// # Panics
    /// Panics if `self` is not the `Azure` variant.
    #[cfg(feature = "azure")]
    fn unwrap_azure(self) -> Arc<object_store::azure::AzureCredential> {
        match self {
            Self::Azure(v) => v,
            other => other.panic_type_mismatch("azure"),
        }
    }

    /// Extracts the GCP credential.
    ///
    /// # Panics
    /// Panics if `self` is not the `Gcp` variant.
    #[cfg(feature = "gcp")]
    fn unwrap_gcp(self) -> Arc<object_store::gcp::GcpCredential> {
        match self {
            Self::Gcp(v) => v,
            other => other.panic_type_mismatch("gcp"),
        }
    }
}
153
/// Conversion of a credential source into the provider types expected by `object_store`.
///
/// The per-cloud conversions have default bodies that call `unimplemented!()`, so an
/// implementor that does not override one of them panics when that conversion is requested.
pub trait IntoCredentialProvider: Sized {
    /// Converts `self` into an AWS credential provider. The default body panics.
    #[cfg(feature = "aws")]
    fn into_aws_provider(self) -> object_store::aws::AwsCredentialProvider {
        unimplemented!()
    }

    /// Converts `self` into an Azure credential provider. The default body panics.
    #[cfg(feature = "azure")]
    fn into_azure_provider(self) -> object_store::azure::AzureCredentialProvider {
        unimplemented!()
    }

    /// Converts `self` into a GCP credential provider. The default body panics.
    #[cfg(feature = "gcp")]
    fn into_gcp_provider(self) -> object_store::gcp::GcpCredentialProvider {
        unimplemented!()
    }

    /// Note, technically shouldn't be under the `IntoCredentialProvider` trait, but it's here
    /// for convenience.
    ///
    /// Returns a list of `(key, value)` string pairs.
    fn storage_update_options(&self) -> PolarsResult<Vec<(PlSmallStr, PlSmallStr)>>;
}
174
175impl IntoCredentialProvider for PlCredentialProvider {
176    #[cfg(feature = "aws")]
177    fn into_aws_provider(self) -> object_store::aws::AwsCredentialProvider {
178        match self {
179            Self::Function(v) => v.into_aws_provider(),
180            #[cfg(feature = "python")]
181            Self::Python(v) => v.into_aws_provider(),
182        }
183    }
184
185    #[cfg(feature = "azure")]
186    fn into_azure_provider(self) -> object_store::azure::AzureCredentialProvider {
187        match self {
188            Self::Function(v) => v.into_azure_provider(),
189            #[cfg(feature = "python")]
190            Self::Python(v) => v.into_azure_provider(),
191        }
192    }
193
194    #[cfg(feature = "gcp")]
195    fn into_gcp_provider(self) -> object_store::gcp::GcpCredentialProvider {
196        match self {
197            Self::Function(v) => v.into_gcp_provider(),
198            #[cfg(feature = "python")]
199            Self::Python(v) => v.into_gcp_provider(),
200        }
201    }
202
203    fn storage_update_options(&self) -> PolarsResult<Vec<(PlSmallStr, PlSmallStr)>> {
204        match self {
205            Self::Function(v) => v.storage_update_options(),
206            #[cfg(feature = "python")]
207            Self::Python(v) => v.storage_update_options(),
208        }
209    }
210}
211
/// Shared, type-erased credential callback.
///
/// Calling it yields a boxed future that resolves to a credential paired with a `u64`
/// (the expiry time in seconds since UNIX_EPOCH, per `PlCredentialProvider::from_func`).
type CredentialProviderFunctionImpl = Arc<
    dyn Fn() -> Pin<
            Box<dyn Future<Output = PolarsResult<(ObjectStoreCredential, u64)>> + Send + Sync>,
        > + Send
        + Sync,
>;
218
/// Wrapper that implements [`IntoCredentialProvider`], [`Debug`], [`PartialEq`], [`Hash`] etc.
///
/// Cloning only clones the inner `Arc`, so clones share the same underlying closure.
#[derive(Clone)]
pub struct CredentialProviderFunction(CredentialProviderFunctionImpl);
222
/// Expands to a function that wraps any error into `object_store::Error::Generic`, using the
/// expression `$s` as the `store` name and boxing the error as its `source`.
///
/// The expansion is a block that evaluates to the function item, so it can be passed
/// directly to `map_err`.
macro_rules! build_to_object_store_err {
    ($s:expr) => {{
        fn to_object_store_err(
            e: impl std::error::Error + Send + Sync + 'static,
        ) -> object_store::Error {
            object_store::Error::Generic {
                store: $s,
                source: Box::new(e),
            }
        }

        to_object_store_err
    }};
}
237
/// Each conversion pairs the callback with a `FetchedCredentialsCache` so that the callback
/// is only awaited when the cache decides the stored credentials need refreshing.
impl IntoCredentialProvider for CredentialProviderFunction {
    /// Wraps the callback into an `object_store` AWS credential provider.
    ///
    /// # Panics
    /// `get_credential` panics if the callback returns a non-AWS credential (`unwrap_aws`).
    #[cfg(feature = "aws")]
    fn into_aws_provider(self) -> object_store::aws::AwsCredentialProvider {
        // Field 0: the callback. Field 1: the cache of the last fetched credentials.
        #[derive(Debug)]
        struct S(
            CredentialProviderFunction,
            FetchedCredentialsCache<Arc<object_store::aws::AwsCredential>>,
        );

        #[async_trait]
        impl object_store::CredentialProvider for S {
            type Credential = object_store::aws::AwsCredential;

            async fn get_credential(&self) -> object_store::Result<Arc<Self::Credential>> {
                self.1
                    .get_maybe_update(async {
                        // Invoke the user callback and await the future it returns.
                        let (creds, expiry) = self.0.0().await?;
                        PolarsResult::Ok((creds.unwrap_aws(), expiry))
                    })
                    .await
                    .map_err(build_to_object_store_err!("credential-provider-aws"))
            }
        }

        // The cache is seeded with an empty placeholder credential; the cache starts with
        // expiry 0, so the placeholder is replaced on the first `get_credential` call.
        Arc::new(S(
            self,
            FetchedCredentialsCache::new(Arc::new(AwsCredential {
                key_id: String::new(),
                secret_key: String::new(),
                token: None,
            })),
        ))
    }

    /// Wraps the callback into an `object_store` Azure credential provider.
    ///
    /// # Panics
    /// `get_credential` panics if the callback returns a non-Azure credential
    /// (`unwrap_azure`).
    #[cfg(feature = "azure")]
    fn into_azure_provider(self) -> object_store::azure::AzureCredentialProvider {
        // Field 0: the callback. Field 1: the cache of the last fetched credentials.
        #[derive(Debug)]
        struct S(
            CredentialProviderFunction,
            FetchedCredentialsCache<Arc<object_store::azure::AzureCredential>>,
        );

        #[async_trait]
        impl object_store::CredentialProvider for S {
            type Credential = object_store::azure::AzureCredential;

            async fn get_credential(&self) -> object_store::Result<Arc<Self::Credential>> {
                self.1
                    .get_maybe_update(async {
                        // Invoke the user callback and await the future it returns.
                        let (creds, expiry) = self.0.0().await?;
                        PolarsResult::Ok((creds.unwrap_azure(), expiry))
                    })
                    .await
                    .map_err(build_to_object_store_err!("credential-provider-azure"))
            }
        }

        // Seeded with an empty bearer token as a placeholder until the first fetch.
        Arc::new(S(
            self,
            FetchedCredentialsCache::new(Arc::new(AzureCredential::BearerToken(String::new()))),
        ))
    }

    /// Wraps the callback into an `object_store` GCP credential provider.
    ///
    /// # Panics
    /// `get_credential` panics if the callback returns a non-GCP credential (`unwrap_gcp`).
    #[cfg(feature = "gcp")]
    fn into_gcp_provider(self) -> object_store::gcp::GcpCredentialProvider {
        // Field 0: the callback. Field 1: the cache of the last fetched credentials.
        #[derive(Debug)]
        struct S(
            CredentialProviderFunction,
            FetchedCredentialsCache<Arc<object_store::gcp::GcpCredential>>,
        );

        #[async_trait]
        impl object_store::CredentialProvider for S {
            type Credential = object_store::gcp::GcpCredential;

            async fn get_credential(&self) -> object_store::Result<Arc<Self::Credential>> {
                self.1
                    .get_maybe_update(async {
                        // Invoke the user callback and await the future it returns.
                        let (creds, expiry) = self.0.0().await?;
                        PolarsResult::Ok((creds.unwrap_gcp(), expiry))
                    })
                    .await
                    .map_err(build_to_object_store_err!("credential-provider-gcp"))
            }
        }

        // Seeded with an empty bearer string as a placeholder until the first fetch.
        Arc::new(S(
            self,
            FetchedCredentialsCache::new(Arc::new(GcpCredential {
                bearer: String::new(),
            })),
        ))
    }

    /// A plain function provider contributes no storage option updates.
    fn storage_update_options(&self) -> PolarsResult<Vec<(PlSmallStr, PlSmallStr)>> {
        Ok(vec![])
    }
}
336
337impl Debug for CredentialProviderFunction {
338    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
339        write!(
340            f,
341            "credential provider function at 0x{:016x}",
342            self.0.as_ref() as *const _ as *const () as usize
343        )
344    }
345}
346
// Marker impl: equality is the pointer comparison defined in the `PartialEq` impl.
impl Eq for CredentialProviderFunction {}
348
impl PartialEq for CredentialProviderFunction {
    /// Two wrappers are equal when they share the same `Arc` allocation; the closures'
    /// behavior is never compared.
    fn eq(&self, other: &Self) -> bool {
        Arc::ptr_eq(&self.0, &other.0)
    }
}
354
355impl Hash for CredentialProviderFunction {
356    fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
357        state.write_usize(Arc::as_ptr(&self.0) as *const () as usize)
358    }
359}
360
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for PlCredentialProvider {
    /// Only the Python-backed provider can be deserialized; without the `python` feature
    /// this always fails.
    fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        #[cfg(feature = "python")]
        {
            let provider = PythonCredentialProvider::deserialize(_deserializer)?;
            Ok(Self::Python(provider))
        }
        #[cfg(not(feature = "python"))]
        {
            use serde::de::Error;

            let message = "cannot deserialize PlCredentialProvider";
            Err(D::Error::custom(message))
        }
    }
}
380
#[cfg(feature = "serde")]
impl serde::Serialize for PlCredentialProvider {
    /// Only the Python-backed provider can be serialized; a Rust closure provider yields a
    /// serialization error.
    fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        use serde::ser::Error;

        match self {
            #[cfg(feature = "python")]
            Self::Python(provider) => provider.serialize(_serializer),
            Self::Function(_) => Err(S::Error::custom(format!("cannot serialize {self:?}"))),
        }
    }
}
397
#[cfg(feature = "dsl-schema")]
impl schemars::JsonSchema for PlCredentialProvider {
    /// Short schema name.
    fn schema_name() -> std::borrow::Cow<'static, str> {
        std::borrow::Cow::Borrowed("PlCredentialProvider")
    }

    /// Schema id: the type name qualified with this module's path.
    fn schema_id() -> std::borrow::Cow<'static, str> {
        let qualified: &'static str = concat!(module_path!(), "::", "PlCredentialProvider");
        qualified.into()
    }

    /// Reuses the schema of `Vec<u8>` for this type.
    fn json_schema(generator: &mut schemars::SchemaGenerator) -> schemars::Schema {
        <Vec<u8> as schemars::JsonSchema>::json_schema(generator)
    }
}
412
/// Avoids calling the credential provider function if we have not yet passed the expiry time.
///
/// The mutex guards a tuple of:
/// * the most recently fetched credentials,
/// * their expiry time in seconds since `UNIX_EPOCH` (`0` until the first fetch),
/// * a flag that limits the verbose "using cached credentials" message to once per fetch.
#[derive(Debug)]
struct FetchedCredentialsCache<C>(tokio::sync::Mutex<(C, u64, bool)>);
416
417impl<C: Clone> FetchedCredentialsCache<C> {
418    fn new(init_creds: C) -> Self {
419        Self(tokio::sync::Mutex::new((init_creds, 0, true)))
420    }
421
422    async fn get_maybe_update(
423        &self,
424        // Taking an `impl Future` here allows us to potentially avoid a `Box::pin` allocation from
425        // a `Fn() -> Pin<Box<dyn Future>>` by having it wrapped in an `async { f() }` block. We
426        // will not poll that block if the credentials have not yet expired.
427        update_func: impl Future<Output = PolarsResult<(C, u64)>>,
428    ) -> PolarsResult<C> {
429        let verbose = config::verbose();
430
431        fn expiry_msg(last_fetched_expiry: u64, now: u64) -> String {
432            if last_fetched_expiry == u64::MAX {
433                "expiry = (never expires)".into()
434            } else {
435                format!(
436                    "expiry = {} (in {} seconds)",
437                    last_fetched_expiry,
438                    last_fetched_expiry.saturating_sub(now)
439                )
440            }
441        }
442
443        let mut inner = self.0.lock().await;
444        let (last_fetched_credentials, last_fetched_expiry, log_use_cached) = &mut *inner;
445
446        let current_time = SystemTime::now()
447            .duration_since(UNIX_EPOCH)
448            .unwrap()
449            .as_secs();
450
451        if *last_fetched_expiry <= current_time {
452            if verbose {
453                eprintln!(
454                    "[FetchedCredentialsCache]: \
455                    Call update_func: current_time = {}, \
456                    last_fetched_expiry = {}",
457                    current_time, *last_fetched_expiry
458                )
459            }
460
461            let (credentials, expiry) = update_func.await?;
462
463            *last_fetched_credentials = credentials;
464            *last_fetched_expiry = expiry;
465            *log_use_cached = true;
466
467            if expiry < current_time && expiry != 0 {
468                polars_bail!(
469                    ComputeError:
470                    "credential expiry time {} is older than system time {} \
471                     by {} seconds",
472                    expiry,
473                    current_time,
474                    current_time - expiry
475                )
476            }
477
478            if verbose {
479                eprintln!(
480                    "[FetchedCredentialsCache]: Finish update_func: new {}",
481                    expiry_msg(
482                        *last_fetched_expiry,
483                        SystemTime::now()
484                            .duration_since(UNIX_EPOCH)
485                            .unwrap()
486                            .as_secs()
487                    )
488                )
489            }
490        } else if verbose && *log_use_cached {
491            *log_use_cached = false;
492            let now = SystemTime::now()
493                .duration_since(UNIX_EPOCH)
494                .unwrap()
495                .as_secs();
496            eprintln!(
497                "[FetchedCredentialsCache]: Using cached credentials: \
498                current_time = {}, {}",
499                now,
500                expiry_msg(*last_fetched_expiry, now)
501            )
502        }
503
504        Ok(last_fetched_credentials.clone())
505    }
506}
507
508#[cfg(feature = "python")]
509mod python_impl {
510    use std::hash::Hash;
511    use std::sync::Arc;
512
513    use polars_error::{PolarsError, PolarsResult, polars_err};
514    use polars_utils::pl_str::PlSmallStr;
515    use polars_utils::python_function::PythonObject;
516    use pyo3::exceptions::PyValueError;
517    use pyo3::pybacked::PyBackedStr;
518    use pyo3::types::{PyAnyMethods, PyDict, PyDictMethods};
519    use pyo3::{Python, intern};
520
521    use super::IntoCredentialProvider;
522
    /// A credential provider backed by a Python object.
    ///
    /// The object is either a builder that still has to be called (`Builder`) or the final
    /// provider (`Provider`); `try_into_initialized` turns the former into the latter.
    #[derive(Clone, Debug)]
    #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
    pub enum PythonCredentialProvider {
        #[cfg_attr(
            feature = "serde",
            serde(
                serialize_with = "PythonObject::serialize_with_pyversion",
                deserialize_with = "PythonObject::deserialize_with_pyversion"
            )
        )]
        /// Indicates `py_object` is a `CredentialProviderBuilder`.
        Builder(Arc<PythonObject>),
        #[cfg_attr(
            feature = "serde",
            serde(
                serialize_with = "PythonObject::serialize_with_pyversion",
                deserialize_with = "PythonObject::deserialize_with_pyversion"
            )
        )]
        /// Indicates `py_object` is an instantiated credential provider
        Provider(Arc<PythonObject>),
    }
545
546    impl PythonCredentialProvider {
547        /// Performs initialization if necessary.
548        ///
549        /// This exists as a separate step that must be called beforehand. This approach is easier
550        /// as the alternative is to refactor the `IntoCredentialProvider` trait to return
551        /// `PolarsResult<Option<T>>` for every single function.
552        pub(super) fn try_into_initialized(
553            self,
554            clear_cached_credentials: bool,
555        ) -> PolarsResult<Option<Self>> {
556            match self {
557                Self::Builder(py_object) => {
558                    let opt_initialized_py_object = Python::attach(|py| {
559                        let build_fn =
560                            py_object.getattr(py, intern!(py, "build_credential_provider"))?;
561
562                        let v = build_fn.call1(py, (clear_cached_credentials,))?;
563                        let v = (!v.is_none(py)).then_some(v);
564
565                        pyo3::PyResult::Ok(v)
566                    })?;
567
568                    Ok(opt_initialized_py_object
569                        .map(PythonObject)
570                        .map(Arc::new)
571                        .map(Self::Provider))
572                },
573                Self::Provider(_) => {
574                    // Note: We don't expect to hit here.
575                    Ok(Some(self))
576                },
577            }
578        }
579
580        fn unwrap_as_provider(self) -> Arc<PythonObject> {
581            match self {
582                Self::Builder(_) => panic!(),
583                Self::Provider(v) => v,
584            }
585        }
586
587        pub(crate) fn unwrap_as_provider_ref(&self) -> &Arc<PythonObject> {
588            match self {
589                Self::Builder(_) => panic!(),
590                Self::Provider(v) => v,
591            }
592        }
593
594        pub(super) fn func_addr(&self) -> usize {
595            (match self {
596                Self::Builder(v) => Arc::as_ptr(v),
597                Self::Provider(v) => Arc::as_ptr(v),
598            }) as *const () as usize
599        }
600
601        pub fn stable_cache_key(&self) -> PolarsResult<Vec<u8>> {
602            let obj = match self {
603                Self::Builder(obj) | Self::Provider(obj) => obj,
604            };
605            let err = |e| {
606                polars_err!(ComputeError:
607                "failed to extract stable_cache_key: {e}")
608            };
609            Python::attach(|py| {
610                obj.call_method0(py, "stable_cache_key")
611                    .and_then(|r| r.extract::<Vec<u8>>(py))
612            })
613            .map_err(err)
614        }
615    }
616
617    impl IntoCredentialProvider for PythonCredentialProvider {
        /// Wraps the Python provider callable into a `CredentialProviderFunction` and
        /// converts that into an AWS credential provider.
        ///
        /// On each fetch the callable is invoked with no arguments and its result is
        /// extracted as `(dict, Option<u64>)`: the storage options and the expiry. A missing
        /// expiry is mapped to `u64::MAX`.
        ///
        /// # Panics
        /// Panics if `self` is not an initialized provider.
        #[cfg(feature = "aws")]
        fn into_aws_provider(self) -> object_store::aws::AwsCredentialProvider {
            use polars_error::PolarsResult;

            use crate::cloud::credential_provider::{
                CredentialProviderFunction, ObjectStoreCredential,
            };

            let func = self.unwrap_as_provider();

            CredentialProviderFunction(Arc::new(move || {
                // Each invocation gets its own handle so the returned future owns it.
                let func = func.clone();
                Box::pin(async move {
                    // Start empty; the fields are filled from the returned dict below.
                    let mut credentials = object_store::aws::AwsCredential {
                        key_id: String::new(),
                        secret_key: String::new(),
                        token: None,
                    };

                    let expiry = Python::attach(|py| {
                        let v = func.0.call0(py)?.into_bound(py);
                        let (storage_options, expiry) =
                            v.extract::<(pyo3::Bound<'_, PyDict>, Option<u64>)>()?;

                        for (k, v) in storage_options.iter() {
                            let k = k.extract::<PyBackedStr>()?;
                            let v = v.extract::<Option<String>>()?;

                            match k.as_ref() {
                                "aws_access_key_id" => {
                                    credentials.key_id = v.ok_or_else(|| {
                                        PyValueError::new_err("aws_access_key_id was None")
                                    })?;
                                },
                                "aws_secret_access_key" => {
                                    credentials.secret_key = v.ok_or_else(|| {
                                        PyValueError::new_err("aws_secret_access_key was None")
                                    })?
                                },
                                // The session token is optional, so `None` is accepted.
                                "aws_session_token" => credentials.token = v,
                                // Any other key is rejected rather than ignored.
                                v => {
                                    return pyo3::PyResult::Err(PyValueError::new_err(format!(
                                        "unknown configuration key for aws: {}, \
                                    valid configuration keys are: \
                                    {}, {}, {}",
                                        v,
                                        "aws_access_key_id",
                                        "aws_secret_access_key",
                                        "aws_session_token"
                                    )));
                                },
                            }
                        }

                        pyo3::PyResult::Ok(expiry.unwrap_or(u64::MAX))
                    })?;

                    // Still-empty fields mean the key was absent or given as "".
                    if credentials.key_id.is_empty() {
                        return Err(PolarsError::ComputeError(
                            "aws_access_key_id was empty or not given".into(),
                        ));
                    }

                    if credentials.secret_key.is_empty() {
                        return Err(PolarsError::ComputeError(
                            "aws_secret_access_key was empty or not given".into(),
                        ));
                    }

                    PolarsResult::Ok((ObjectStoreCredential::Aws(Arc::new(credentials)), expiry))
                })
            }))
            .into_aws_provider()
        }
692
        /// Wraps the Python provider callable into a `CredentialProviderFunction` and
        /// converts that into an Azure credential provider.
        ///
        /// On each fetch the callable is invoked with no arguments and its result is
        /// extracted as `(dict, Option<u64>)`: the storage options and the expiry. A missing
        /// expiry is mapped to `u64::MAX`. Accepted keys are `account_key`, `bearer_token`
        /// and `sas_token`; if several are present, the one iterated last wins.
        ///
        /// # Panics
        /// Panics if `self` is not an initialized provider.
        #[cfg(feature = "azure")]
        fn into_azure_provider(self) -> object_store::azure::AzureCredentialProvider {
            use object_store::azure::AzureAccessKey;
            use percent_encoding::percent_decode_str;
            use polars_core::config::verbose_print_sensitive;
            use polars_error::PolarsResult;

            use crate::cloud::credential_provider::{
                CredentialProviderFunction, ObjectStoreCredential,
            };

            let func = self.unwrap_as_provider();

            // Explicit `return` because the helper `split_sas` is declared after this
            // expression.
            return CredentialProviderFunction(Arc::new(move || {
                // Each invocation gets its own handle so the returned future owns it.
                let func = func.clone();
                Box::pin(async move {
                    let mut credentials = None;

                    static VALID_KEYS_MSG: &str =
                        "valid configuration keys are: ('account_key', 'bearer_token', 'sas_token')";

                    let expiry = Python::attach(|py| {
                        let v = func.0.call0(py)?.into_bound(py);
                        let (storage_options, expiry) =
                            v.extract::<(pyo3::Bound<'_, PyDict>, Option<u64>)>()?;

                        for (k, v) in storage_options.iter() {
                            let k = k.extract::<PyBackedStr>()?;
                            let v = v.extract::<String>()?;

                            match k.as_ref() {
                                "account_key" => {
                                    credentials =
                                        Some(object_store::azure::AzureCredential::AccessKey(
                                            AzureAccessKey::try_new(v.as_str()).map_err(|e| {
                                                PyValueError::new_err(e.to_string())
                                            })?,
                                        ))
                                },
                                "bearer_token" => {
                                    credentials =
                                        Some(object_store::azure::AzureCredential::BearerToken(v))
                                },
                                "sas_token" => {
                                    credentials =
                                        Some(object_store::azure::AzureCredential::SASToken(
                                            split_sas(&v).map_err(|err_msg| {
                                                // The token itself only goes to the
                                                // sensitive-verbose channel, never into the
                                                // returned error.
                                                verbose_print_sensitive(|| {
                                                    format!("error decoding SAS token: {err_msg} (token: {v})")
                                                });

                                                PyValueError::new_err(format!(
                                                    "error decoding SAS token: {err_msg}. \
                                                    Set POLARS_VERBOSE_SENSITIVE=1 to print the value"
                                                ))
                                            })?,
                                        ))
                                },
                                // Any other key is rejected rather than ignored.
                                v => {
                                    return pyo3::PyResult::Err(PyValueError::new_err(format!(
                                        "unknown configuration key for azure: {v}, {VALID_KEYS_MSG}"
                                    )));
                                },
                            }
                        }

                        pyo3::PyResult::Ok(expiry.unwrap_or(u64::MAX))
                    })?;

                    // `None` here means the dict contained none of the accepted keys.
                    let Some(credentials) = credentials else {
                        return Err(PolarsError::ComputeError(
                            format!(
                                "did not find a valid configuration key for azure, {VALID_KEYS_MSG}"
                            )
                            .into(),
                        ));
                    };

                    PolarsResult::Ok((ObjectStoreCredential::Azure(Arc::new(credentials)), expiry))
                })
            }))
            .into_azure_provider();

            /// Copied and adjusted from object-store.
            ///
            /// Percent-decodes `sas`, strips leading `?` characters and splits the rest on
            /// `&` into `key=value` pairs, skipping whitespace-only segments.
            ///
            /// https://github.com/apache/arrow-rs-object-store/blob/7a0504b4924fcecee17d768fd7190b8f71b0877f/src/azure/builder.rs#L1072-L1089
            fn split_sas(sas: &str) -> Result<Vec<(String, String)>, &'static str> {
                let sas = percent_decode_str(sas)
                    .decode_utf8()
                    .map_err(|_| "UTF-8 decode error")?;

                let kv_str_pairs = sas
                    .trim_start_matches('?')
                    .split('&')
                    .filter(|s| !s.chars().all(char::is_whitespace));

                let mut pairs = Vec::new();

                for kv_pair_str in kv_str_pairs {
                    // A segment without `=` is an error.
                    let (k, v) = kv_pair_str
                        .trim()
                        .split_once('=')
                        .ok_or("missing SAS component")?;
                    pairs.push((k.into(), v.into()))
                }

                Ok(pairs)
            }
        }
802
        /// Wraps the Python provider callable into a `CredentialProviderFunction` and
        /// converts that into a GCP credential provider.
        ///
        /// On each fetch the callable is invoked with no arguments and its result is
        /// extracted as `(dict, Option<u64>)`: the storage options and the expiry. A missing
        /// expiry is mapped to `u64::MAX`. The only accepted key is `bearer_token`.
        ///
        /// # Panics
        /// Panics if `self` is not an initialized provider.
        #[cfg(feature = "gcp")]
        fn into_gcp_provider(self) -> object_store::gcp::GcpCredentialProvider {
            use polars_error::PolarsResult;

            use crate::cloud::credential_provider::{
                CredentialProviderFunction, ObjectStoreCredential,
            };

            let func = self.unwrap_as_provider();

            CredentialProviderFunction(Arc::new(move || {
                // Each invocation gets its own handle so the returned future owns it.
                let func = func.clone();
                Box::pin(async move {
                    // Start empty; the field is filled from the returned dict below.
                    let mut credentials = object_store::gcp::GcpCredential {
                        bearer: String::new(),
                    };

                    let expiry = Python::attach(|py| {
                        let v = func.0.call0(py)?.into_bound(py);
                        let (storage_options, expiry) =
                            v.extract::<(pyo3::Bound<'_, PyDict>, Option<u64>)>()?;

                        for (k, v) in storage_options.iter() {
                            let k = k.extract::<PyBackedStr>()?;
                            let v = v.extract::<String>()?;

                            match k.as_ref() {
                                "bearer_token" => credentials.bearer = v,
                                // Any other key is rejected rather than ignored.
                                v => {
                                    return pyo3::PyResult::Err(PyValueError::new_err(format!(
                                        "unknown configuration key for gcp: {}, \
                                    valid configuration keys are: {}",
                                        v, "bearer_token",
                                    )));
                                },
                            }
                        }

                        pyo3::PyResult::Ok(expiry.unwrap_or(u64::MAX))
                    })?;

                    // NOTE(review): this message says `bearer` while the accepted dict key is
                    // `bearer_token` — confirm whether the wording should name the key.
                    if credentials.bearer.is_empty() {
                        return Err(PolarsError::ComputeError(
                            "bearer was empty or not given".into(),
                        ));
                    }

                    PolarsResult::Ok((ObjectStoreCredential::Gcp(Arc::new(credentials)), expiry))
                })
            }))
            .into_gcp_provider()
        }
855
856        /// # Panics
857        /// Panics if `self` is not an initialized provider.
858        fn storage_update_options(&self) -> PolarsResult<Vec<(PlSmallStr, PlSmallStr)>> {
859            let py_object = self.unwrap_as_provider_ref();
860
861            Python::attach(|py| {
862                py_object
863                    .getattr(py, "_storage_update_options")
864                    .map_or(Ok(vec![]), |f| {
865                        let v = f
866                            .call0(py)?
867                            .extract::<pyo3::Bound<'_, PyDict>>(py)
868                            .map_err(pyo3::PyErr::from)?;
869
870                        let mut out = Vec::with_capacity(v.len());
871
872                        for dict_item in v.call_method0("items")?.try_iter()? {
873                            let (key, value) =
874                                dict_item?.extract::<(PyBackedStr, PyBackedStr)>()?;
875
876                            out.push(((&*key).into(), (&*value).into()))
877                        }
878
879                        Ok(out)
880                    })
881            })
882        }
883    }
884
885    // Note: We don't consider `is_builder` for hash/eq - we don't expect the same Arc<PythonObject>
886    // to be referenced as both true and false from the `is_builder` field.
887
888    impl Eq for PythonCredentialProvider {}
889
890    impl PartialEq for PythonCredentialProvider {
891        fn eq(&self, other: &Self) -> bool {
892            self.func_addr() == other.func_addr()
893        }
894    }
895
896    impl Hash for PythonCredentialProvider {
897        fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
898            // # Safety
899            // * Inner is an `Arc`
900            // * Visibility is limited to super
901            // * No code in `mod python_impl` or `super` mutates the Arc inner.
902            state.write_usize(self.func_addr())
903        }
904    }
905}
906
#[cfg(test)]
mod tests {
    /// Checks the serde behaviour of `Option<PlCredentialProvider>`:
    /// a function-backed provider must fail to serialize, while `None`
    /// serializes and round-trips back to `None`.
    #[cfg(feature = "serde")]
    #[test]
    fn test_serde() {
        use super::*;

        // A provider built from a Rust closure cannot be serialized.
        let function_provider = PlCredentialProvider::from_func(|| {
            Box::pin(core::future::ready(PolarsResult::Ok((
                ObjectStoreCredential::None,
                0,
            ))))
        });
        assert!(serde_json::to_string(&Some(function_provider)).is_err());

        // The absence of a provider serializes without error...
        let none_json = serde_json::to_string(&Option::<PlCredentialProvider>::None);
        assert!(none_json.is_ok());

        // ...and deserializes back to `None`.
        let round_trip =
            serde_json::from_str::<Option<PlCredentialProvider>>(none_json.unwrap().as_str());
        assert!(matches!(round_trip, Ok(None)));
    }
}