1use std::fmt::Debug;
2use std::future::Future;
3use std::hash::Hash;
4use std::pin::Pin;
5use std::sync::Arc;
6use std::time::{SystemTime, UNIX_EPOCH};
7
8use async_trait::async_trait;
9#[cfg(feature = "aws")]
10pub use object_store::aws::AwsCredential;
11#[cfg(feature = "azure")]
12pub use object_store::azure::AzureCredential;
13#[cfg(feature = "gcp")]
14pub use object_store::gcp::GcpCredential;
15use polars_core::config;
16use polars_error::{PolarsResult, polars_bail};
17#[cfg(feature = "python")]
18use polars_utils::python_function::PythonObject;
19#[cfg(feature = "python")]
20use python_impl::PythonCredentialProvider;
21
22#[derive(Clone, Debug, PartialEq, Hash, Eq)]
23pub enum PlCredentialProvider {
24 Function(CredentialProviderFunction),
26 #[cfg(feature = "python")]
27 Python(python_impl::PythonCredentialProvider),
28}
29
30impl PlCredentialProvider {
31 pub fn from_func(
35 func: impl Fn() -> Pin<
38 Box<dyn Future<Output = PolarsResult<(ObjectStoreCredential, u64)>> + Send + Sync>,
39 > + Send
40 + Sync
41 + 'static,
42 ) -> Self {
43 Self::Function(CredentialProviderFunction(Arc::new(func)))
44 }
45
46 #[cfg(feature = "python")]
49 pub fn from_python_builder(func: pyo3::PyObject) -> Self {
50 Self::Python(python_impl::PythonCredentialProvider::Builder(Arc::new(
51 PythonObject(func),
52 )))
53 }
54
55 pub(super) fn func_addr(&self) -> usize {
56 match self {
57 Self::Function(CredentialProviderFunction(v)) => Arc::as_ptr(v) as *const () as usize,
58 #[cfg(feature = "python")]
59 Self::Python(v) => v.func_addr(),
60 }
61 }
62
63 pub(crate) fn try_into_initialized(self) -> PolarsResult<Option<Self>> {
68 match self {
69 Self::Function(_) => Ok(Some(self)),
70 #[cfg(feature = "python")]
71 Self::Python(v) => Ok(v.try_into_initialized()?.map(Self::Python)),
72 }
73 }
74}
75
/// A credential produced by a credential provider function, tagged by cloud vendor.
pub enum ObjectStoreCredential {
    #[cfg(feature = "aws")]
    Aws(Arc<object_store::aws::AwsCredential>),
    #[cfg(feature = "azure")]
    Azure(Arc<object_store::azure::AzureCredential>),
    #[cfg(feature = "gcp")]
    Gcp(Arc<object_store::gcp::GcpCredential>),
    /// Carries no credential.
    None,
}

impl ObjectStoreCredential {
    /// Name of the active variant, used in diagnostics.
    fn variant_name(&self) -> &'static str {
        match self {
            #[cfg(feature = "aws")]
            Self::Aws(_) => "Aws",
            #[cfg(feature = "azure")]
            Self::Azure(_) => "Azure",
            #[cfg(feature = "gcp")]
            Self::Gcp(_) => "Gcp",
            Self::None => "None",
        }
    }

    /// Panics, reporting that a credential of kind `expected` was required but `self` holds
    /// a different variant.
    ///
    /// Declared as diverging (`-> !`) so call sites need no trailing `unreachable!()`.
    #[cold]
    fn panic_type_mismatch(&self, expected: &str) -> ! {
        panic!(
            "impl error: credential type mismatch: expected {}, got {} instead",
            expected,
            self.variant_name()
        )
    }

    /// Extracts the AWS credential.
    ///
    /// # Panics
    /// Panics if `self` is not the `Aws` variant.
    #[cfg(feature = "aws")]
    fn unwrap_aws(self) -> Arc<object_store::aws::AwsCredential> {
        match self {
            Self::Aws(credential) => credential,
            other => other.panic_type_mismatch("aws"),
        }
    }

    /// Extracts the Azure credential.
    ///
    /// # Panics
    /// Panics if `self` is not the `Azure` variant.
    #[cfg(feature = "azure")]
    fn unwrap_azure(self) -> Arc<object_store::azure::AzureCredential> {
        match self {
            Self::Azure(credential) => credential,
            other => other.panic_type_mismatch("azure"),
        }
    }

    /// Extracts the GCP credential.
    ///
    /// # Panics
    /// Panics if `self` is not the `Gcp` variant.
    #[cfg(feature = "gcp")]
    fn unwrap_gcp(self) -> Arc<object_store::gcp::GcpCredential> {
        match self {
            Self::Gcp(credential) => credential,
            other => other.panic_type_mismatch("gcp"),
        }
    }
}
135
/// Conversion into the vendor-specific credential provider types of `object_store`.
///
/// Each method is only available when the matching cargo feature is enabled. The default
/// bodies panic through `unimplemented!()`, so an implementor overrides just the
/// conversions it supports.
pub trait IntoCredentialProvider: Sized {
    /// Converts `self` into an AWS credential provider.
    #[cfg(feature = "aws")]
    fn into_aws_provider(self) -> object_store::aws::AwsCredentialProvider {
        unimplemented!()
    }

    /// Converts `self` into an Azure credential provider.
    #[cfg(feature = "azure")]
    fn into_azure_provider(self) -> object_store::azure::AzureCredentialProvider {
        unimplemented!()
    }

    /// Converts `self` into a GCP credential provider.
    #[cfg(feature = "gcp")]
    fn into_gcp_provider(self) -> object_store::gcp::GcpCredentialProvider {
        unimplemented!()
    }
}
152
153impl IntoCredentialProvider for PlCredentialProvider {
154 #[cfg(feature = "aws")]
155 fn into_aws_provider(self) -> object_store::aws::AwsCredentialProvider {
156 match self {
157 Self::Function(v) => v.into_aws_provider(),
158 #[cfg(feature = "python")]
159 Self::Python(v) => v.into_aws_provider(),
160 }
161 }
162
163 #[cfg(feature = "azure")]
164 fn into_azure_provider(self) -> object_store::azure::AzureCredentialProvider {
165 match self {
166 Self::Function(v) => v.into_azure_provider(),
167 #[cfg(feature = "python")]
168 Self::Python(v) => v.into_azure_provider(),
169 }
170 }
171
172 #[cfg(feature = "gcp")]
173 fn into_gcp_provider(self) -> object_store::gcp::GcpCredentialProvider {
174 match self {
175 Self::Function(v) => v.into_gcp_provider(),
176 #[cfg(feature = "python")]
177 Self::Python(v) => v.into_gcp_provider(),
178 }
179 }
180}
181
182type CredentialProviderFunctionImpl = Arc<
183 dyn Fn() -> Pin<
184 Box<dyn Future<Output = PolarsResult<(ObjectStoreCredential, u64)>> + Send + Sync>,
185 > + Send
186 + Sync,
187>;
188
189#[derive(Clone)]
191pub struct CredentialProviderFunction(CredentialProviderFunctionImpl);
192
/// Expands to a function that boxes an arbitrary error into
/// `object_store::Error::Generic`, tagged with the store name given to the macro.
macro_rules! build_to_object_store_err {
    ($store:expr) => {{
        fn to_object_store_err<E>(source: E) -> object_store::Error
        where
            E: std::error::Error + Send + Sync + 'static,
        {
            object_store::Error::Generic {
                store: $store,
                source: Box::new(source),
            }
        }

        to_object_store_err
    }};
}
207
208impl IntoCredentialProvider for CredentialProviderFunction {
209 #[cfg(feature = "aws")]
210 fn into_aws_provider(self) -> object_store::aws::AwsCredentialProvider {
211 #[derive(Debug)]
212 struct S(
213 CredentialProviderFunction,
214 FetchedCredentialsCache<Arc<object_store::aws::AwsCredential>>,
215 );
216
217 #[async_trait]
218 impl object_store::CredentialProvider for S {
219 type Credential = object_store::aws::AwsCredential;
220
221 async fn get_credential(&self) -> object_store::Result<Arc<Self::Credential>> {
222 self.1
223 .get_maybe_update(async {
224 let (creds, expiry) = self.0.0().await?;
225 PolarsResult::Ok((creds.unwrap_aws(), expiry))
226 })
227 .await
228 .map_err(build_to_object_store_err!("credential-provider-aws"))
229 }
230 }
231
232 Arc::new(S(
233 self,
234 FetchedCredentialsCache::new(Arc::new(AwsCredential {
235 key_id: String::new(),
236 secret_key: String::new(),
237 token: None,
238 })),
239 ))
240 }
241
242 #[cfg(feature = "azure")]
243 fn into_azure_provider(self) -> object_store::azure::AzureCredentialProvider {
244 #[derive(Debug)]
245 struct S(
246 CredentialProviderFunction,
247 FetchedCredentialsCache<Arc<object_store::azure::AzureCredential>>,
248 );
249
250 #[async_trait]
251 impl object_store::CredentialProvider for S {
252 type Credential = object_store::azure::AzureCredential;
253
254 async fn get_credential(&self) -> object_store::Result<Arc<Self::Credential>> {
255 self.1
256 .get_maybe_update(async {
257 let (creds, expiry) = self.0.0().await?;
258 PolarsResult::Ok((creds.unwrap_azure(), expiry))
259 })
260 .await
261 .map_err(build_to_object_store_err!("credential-provider-azure"))
262 }
263 }
264
265 Arc::new(S(
266 self,
267 FetchedCredentialsCache::new(Arc::new(AzureCredential::BearerToken(String::new()))),
268 ))
269 }
270
271 #[cfg(feature = "gcp")]
272 fn into_gcp_provider(self) -> object_store::gcp::GcpCredentialProvider {
273 #[derive(Debug)]
274 struct S(
275 CredentialProviderFunction,
276 FetchedCredentialsCache<Arc<object_store::gcp::GcpCredential>>,
277 );
278
279 #[async_trait]
280 impl object_store::CredentialProvider for S {
281 type Credential = object_store::gcp::GcpCredential;
282
283 async fn get_credential(&self) -> object_store::Result<Arc<Self::Credential>> {
284 self.1
285 .get_maybe_update(async {
286 let (creds, expiry) = self.0.0().await?;
287 PolarsResult::Ok((creds.unwrap_gcp(), expiry))
288 })
289 .await
290 .map_err(build_to_object_store_err!("credential-provider-gcp"))
291 }
292 }
293
294 Arc::new(S(
295 self,
296 FetchedCredentialsCache::new(Arc::new(GcpCredential {
297 bearer: String::new(),
298 })),
299 ))
300 }
301}
302
303impl Debug for CredentialProviderFunction {
304 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
305 write!(
306 f,
307 "credential provider function at 0x{:016x}",
308 self.0.as_ref() as *const _ as *const () as usize
309 )
310 }
311}
312
313impl Eq for CredentialProviderFunction {}
314
315impl PartialEq for CredentialProviderFunction {
316 fn eq(&self, other: &Self) -> bool {
317 Arc::ptr_eq(&self.0, &other.0)
318 }
319}
320
321impl Hash for CredentialProviderFunction {
322 fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
323 state.write_usize(Arc::as_ptr(&self.0) as *const () as usize)
324 }
325}
326
/// Deserialization is only possible for Python-backed providers; without the `python`
/// feature it always fails.
#[cfg(feature = "serde")]
impl<'de> serde::Deserialize<'de> for PlCredentialProvider {
    fn deserialize<D>(_deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        #[cfg(feature = "python")]
        {
            <PythonCredentialProvider as serde::Deserialize<'de>>::deserialize(_deserializer)
                .map(Self::Python)
        }
        #[cfg(not(feature = "python"))]
        {
            Err(<D::Error as serde::de::Error>::custom(
                "cannot deserialize PlCredentialProvider",
            ))
        }
    }
}
346
/// Only Python-backed providers can be serialized; a function-backed provider yields a
/// serializer error that embeds its `Debug` representation.
#[cfg(feature = "serde")]
impl serde::Serialize for PlCredentialProvider {
    fn serialize<S>(&self, _serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        match self {
            #[cfg(feature = "python")]
            Self::Python(provider) => serde::Serialize::serialize(provider, _serializer),
            Self::Function(_) => Err(<S::Error as serde::ser::Error>::custom(format!(
                "cannot serialize {:?}",
                self
            ))),
        }
    }
}
363
364#[derive(Debug)]
366struct FetchedCredentialsCache<C>(tokio::sync::Mutex<(C, u64)>);
367
368impl<C: Clone> FetchedCredentialsCache<C> {
369 fn new(init_creds: C) -> Self {
370 Self(tokio::sync::Mutex::new((init_creds, 0)))
371 }
372
373 async fn get_maybe_update(
374 &self,
375 update_func: impl Future<Output = PolarsResult<(C, u64)>>,
379 ) -> PolarsResult<C> {
380 let verbose = config::verbose();
381
382 fn expiry_msg(last_fetched_expiry: u64, now: u64) -> String {
383 if last_fetched_expiry == u64::MAX {
384 "expiry = (never expires)".into()
385 } else {
386 format!(
387 "expiry = {} (in {} seconds)",
388 last_fetched_expiry,
389 last_fetched_expiry.saturating_sub(now)
390 )
391 }
392 }
393
394 let mut inner = self.0.lock().await;
395 let (last_fetched_credentials, last_fetched_expiry) = &mut *inner;
396
397 let current_time = SystemTime::now()
398 .duration_since(UNIX_EPOCH)
399 .unwrap()
400 .as_secs();
401
402 const REQUEST_TIME_BUFFER: u64 = 7;
405
406 if last_fetched_expiry.saturating_sub(current_time) < REQUEST_TIME_BUFFER {
407 if verbose {
408 eprintln!(
409 "[FetchedCredentialsCache]: Call update_func: current_time = {}\
410 , last_fetched_expiry = {}",
411 current_time, *last_fetched_expiry
412 )
413 }
414 let (credentials, expiry) = update_func.await?;
415
416 *last_fetched_credentials = credentials;
417 *last_fetched_expiry = expiry;
418
419 if expiry < current_time && expiry != 0 {
420 polars_bail!(
421 ComputeError:
422 "credential expiry time {} is older than system time {} \
423 by {} seconds",
424 expiry,
425 current_time,
426 current_time - expiry
427 )
428 }
429
430 if verbose {
431 eprintln!(
432 "[FetchedCredentialsCache]: Finish update_func: new {}",
433 expiry_msg(
434 *last_fetched_expiry,
435 SystemTime::now()
436 .duration_since(UNIX_EPOCH)
437 .unwrap()
438 .as_secs()
439 )
440 )
441 }
442 } else if verbose {
443 let now = SystemTime::now()
444 .duration_since(UNIX_EPOCH)
445 .unwrap()
446 .as_secs();
447 eprintln!(
448 "[FetchedCredentialsCache]: Using cached credentials: \
449 current_time = {}, {}",
450 now,
451 expiry_msg(*last_fetched_expiry, now)
452 )
453 }
454
455 Ok(last_fetched_credentials.clone())
456 }
457}
458
/// Credential providers backed by Python objects.
#[cfg(feature = "python")]
mod python_impl {
    use std::hash::Hash;
    use std::sync::Arc;

    use polars_error::{PolarsError, PolarsResult};
    use polars_utils::python_function::PythonObject;
    use pyo3::Python;
    use pyo3::exceptions::PyValueError;
    use pyo3::pybacked::PyBackedStr;
    use pyo3::types::{PyAnyMethods, PyDict, PyDictMethods};

    use super::IntoCredentialProvider;

    /// A credential provider implemented by a Python object.
    #[derive(Clone, Debug)]
    #[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
    pub enum PythonCredentialProvider {
        /// Python object exposing a `build_credential_provider` attribute; it has to be
        /// initialized via `try_into_initialized` before it can be converted.
        #[cfg_attr(
            feature = "serde",
            serde(
                serialize_with = "PythonObject::serialize_with_pyversion",
                deserialize_with = "PythonObject::deserialize_with_pyversion"
            )
        )]
        Builder(Arc<PythonObject>),
        /// Python callable that, called without arguments, returns a
        /// `(dict, Optional[int])` tuple of configuration keys and expiry.
        #[cfg_attr(
            feature = "serde",
            serde(
                serialize_with = "PythonObject::serialize_with_pyversion",
                deserialize_with = "PythonObject::deserialize_with_pyversion"
            )
        )]
        Provider(Arc<PythonObject>),
    }

    impl PythonCredentialProvider {
        /// Initializes a `Builder` by calling its `build_credential_provider()` method.
        ///
        /// Returns `Ok(None)` when the Python call returns `None`. A value that is already
        /// a `Provider` is returned unchanged.
        ///
        /// # Errors
        /// Propagates any Python exception raised while looking up or calling
        /// `build_credential_provider`.
        pub(super) fn try_into_initialized(self) -> PolarsResult<Option<Self>> {
            match self {
                Self::Builder(py_object) => {
                    let opt_initialized_py_object = Python::with_gil(|py| {
                        let build_fn = py_object.getattr(py, "build_credential_provider")?;

                        let built = build_fn.call0(py)?;
                        let built = (!built.is_none(py)).then_some(built);

                        pyo3::PyResult::Ok(built)
                    })?;

                    Ok(opt_initialized_py_object
                        .map(PythonObject)
                        .map(Arc::new)
                        .map(Self::Provider))
                },
                Self::Provider(_) => Ok(Some(self)),
            }
        }

        /// Returns the wrapped Python callable of an initialized provider.
        ///
        /// # Panics
        /// Panics if `self` is still a `Builder`.
        fn unwrap_as_provider(self) -> Arc<PythonObject> {
            match self {
                Self::Builder(_) => panic!(
                    "impl error: expected an initialized python credential provider, \
                    got a builder (see `try_into_initialized`)"
                ),
                Self::Provider(provider) => provider,
            }
        }

        /// Address of the shared Python object wrapper, as an integer.
        pub(super) fn func_addr(&self) -> usize {
            (match self {
                Self::Builder(v) => Arc::as_ptr(v),
                Self::Provider(v) => Arc::as_ptr(v),
            }) as *const () as usize
        }
    }

    impl IntoCredentialProvider for PythonCredentialProvider {
        /// Builds an AWS provider that calls the Python callable on every refresh.
        ///
        /// Accepted dict keys are `aws_access_key_id`, `aws_secret_access_key` and
        /// `aws_session_token`. A missing expiry is treated as `u64::MAX`.
        ///
        /// # Panics
        /// Panics if `self` is still a `Builder`.
        #[cfg(feature = "aws")]
        fn into_aws_provider(self) -> object_store::aws::AwsCredentialProvider {
            use crate::cloud::credential_provider::{
                CredentialProviderFunction, ObjectStoreCredential,
            };

            let func = self.unwrap_as_provider();

            CredentialProviderFunction(Arc::new(move || {
                let func = func.clone();
                Box::pin(async move {
                    let mut credentials = object_store::aws::AwsCredential {
                        key_id: String::new(),
                        secret_key: String::new(),
                        token: None,
                    };

                    let expiry = Python::with_gil(|py| {
                        let returned = func.0.call0(py)?.into_bound(py);
                        let (storage_options, expiry) =
                            returned.extract::<(pyo3::Bound<'_, PyDict>, Option<u64>)>()?;

                        for (key, value) in storage_options.iter() {
                            let key = key.extract::<PyBackedStr>()?;
                            let value = value.extract::<Option<String>>()?;

                            match key.as_ref() {
                                "aws_access_key_id" => {
                                    credentials.key_id = value.ok_or_else(|| {
                                        PyValueError::new_err("aws_access_key_id was None")
                                    })?;
                                },
                                "aws_secret_access_key" => {
                                    credentials.secret_key = value.ok_or_else(|| {
                                        PyValueError::new_err("aws_secret_access_key was None")
                                    })?;
                                },
                                "aws_session_token" => credentials.token = value,
                                unknown => {
                                    return pyo3::PyResult::Err(PyValueError::new_err(format!(
                                        "unknown configuration key for aws: {}, \
                                        valid configuration keys are: \
                                        {}, {}, {}",
                                        unknown,
                                        "aws_access_key_id",
                                        "aws_secret_access_key",
                                        "aws_session_token"
                                    )));
                                },
                            }
                        }

                        pyo3::PyResult::Ok(expiry.unwrap_or(u64::MAX))
                    })?;

                    if credentials.key_id.is_empty() {
                        return Err(PolarsError::ComputeError(
                            "aws_access_key_id was empty or not given".into(),
                        ));
                    }

                    if credentials.secret_key.is_empty() {
                        return Err(PolarsError::ComputeError(
                            "aws_secret_access_key was empty or not given".into(),
                        ));
                    }

                    PolarsResult::Ok((ObjectStoreCredential::Aws(Arc::new(credentials)), expiry))
                })
            }))
            .into_aws_provider()
        }

        /// Builds an Azure provider that calls the Python callable on every refresh.
        ///
        /// Accepted dict keys are `account_key` and `bearer_token`; if several are given,
        /// the one iterated last wins. A missing expiry is treated as `u64::MAX`.
        ///
        /// # Panics
        /// Panics if `self` is still a `Builder`.
        #[cfg(feature = "azure")]
        fn into_azure_provider(self) -> object_store::azure::AzureCredentialProvider {
            use object_store::azure::AzureAccessKey;

            use crate::cloud::credential_provider::{
                CredentialProviderFunction, ObjectStoreCredential,
            };

            let func = self.unwrap_as_provider();

            CredentialProviderFunction(Arc::new(move || {
                let func = func.clone();
                Box::pin(async move {
                    const VALID_KEYS_MSG: &str =
                        "valid configuration keys are: account_key, bearer_token";

                    let mut credentials = None;

                    let expiry = Python::with_gil(|py| {
                        let returned = func.0.call0(py)?.into_bound(py);
                        let (storage_options, expiry) =
                            returned.extract::<(pyo3::Bound<'_, PyDict>, Option<u64>)>()?;

                        for (key, value) in storage_options.iter() {
                            let key = key.extract::<PyBackedStr>()?;
                            let value = value.extract::<String>()?;

                            match key.as_ref() {
                                "account_key" => {
                                    credentials =
                                        Some(object_store::azure::AzureCredential::AccessKey(
                                            AzureAccessKey::try_new(value.as_str()).map_err(
                                                |e| PyValueError::new_err(e.to_string()),
                                            )?,
                                        ))
                                },
                                "bearer_token" => {
                                    credentials = Some(
                                        object_store::azure::AzureCredential::BearerToken(value),
                                    )
                                },
                                unknown => {
                                    return pyo3::PyResult::Err(PyValueError::new_err(format!(
                                        "unknown configuration key for azure: {}, {}",
                                        unknown, VALID_KEYS_MSG
                                    )));
                                },
                            }
                        }

                        pyo3::PyResult::Ok(expiry.unwrap_or(u64::MAX))
                    })?;

                    let Some(credentials) = credentials else {
                        return Err(PolarsError::ComputeError(
                            format!(
                                "did not find a valid configuration key for azure, {}",
                                VALID_KEYS_MSG
                            )
                            .into(),
                        ));
                    };

                    PolarsResult::Ok((ObjectStoreCredential::Azure(Arc::new(credentials)), expiry))
                })
            }))
            .into_azure_provider()
        }

        /// Builds a GCP provider that calls the Python callable on every refresh.
        ///
        /// The only accepted dict key is `bearer_token`. A missing expiry is treated as
        /// `u64::MAX`.
        ///
        /// # Panics
        /// Panics if `self` is still a `Builder`.
        #[cfg(feature = "gcp")]
        fn into_gcp_provider(self) -> object_store::gcp::GcpCredentialProvider {
            use crate::cloud::credential_provider::{
                CredentialProviderFunction, ObjectStoreCredential,
            };

            let func = self.unwrap_as_provider();

            CredentialProviderFunction(Arc::new(move || {
                let func = func.clone();
                Box::pin(async move {
                    let mut credentials = object_store::gcp::GcpCredential {
                        bearer: String::new(),
                    };

                    let expiry = Python::with_gil(|py| {
                        let returned = func.0.call0(py)?.into_bound(py);
                        let (storage_options, expiry) =
                            returned.extract::<(pyo3::Bound<'_, PyDict>, Option<u64>)>()?;

                        for (key, value) in storage_options.iter() {
                            let key = key.extract::<PyBackedStr>()?;
                            let value = value.extract::<String>()?;

                            match key.as_ref() {
                                "bearer_token" => credentials.bearer = value,
                                unknown => {
                                    return pyo3::PyResult::Err(PyValueError::new_err(format!(
                                        "unknown configuration key for gcp: {}, \
                                        valid configuration keys are: {}",
                                        unknown, "bearer_token",
                                    )));
                                },
                            }
                        }

                        pyo3::PyResult::Ok(expiry.unwrap_or(u64::MAX))
                    })?;

                    if credentials.bearer.is_empty() {
                        return Err(PolarsError::ComputeError(
                            "bearer was empty or not given".into(),
                        ));
                    }

                    PolarsResult::Ok((ObjectStoreCredential::Gcp(Arc::new(credentials)), expiry))
                })
            }))
            .into_gcp_provider()
        }
    }

    impl Eq for PythonCredentialProvider {}

    /// Identity comparison based on the address of the wrapped Python object.
    impl PartialEq for PythonCredentialProvider {
        fn eq(&self, other: &Self) -> bool {
            self.func_addr() == other.func_addr()
        }
    }

    /// Hashes the address of the wrapped Python object, consistent with `PartialEq`.
    impl Hash for PythonCredentialProvider {
        fn hash<H: std::hash::Hasher>(&self, state: &mut H) {
            state.write_usize(self.func_addr())
        }
    }
}
761
#[cfg(test)]
mod tests {
    /// Checks the serde behavior of `Option<PlCredentialProvider>`.
    #[cfg(feature = "serde")]
    #[test]
    fn test_serde() {
        use super::*;

        // A function-backed provider must be rejected by the serializer.
        let function_provider = Some(PlCredentialProvider::from_func(|| {
            Box::pin(core::future::ready(PolarsResult::Ok((
                ObjectStoreCredential::None,
                0,
            ))))
        }));
        assert!(serde_json::to_string(&function_provider).is_err());

        // An absent provider serializes successfully ...
        let absent = Option::<PlCredentialProvider>::None;
        let encoded = serde_json::to_string(&absent);
        assert!(encoded.is_ok());

        // ... and round-trips back to `None`.
        let encoded = encoded.unwrap();
        let decoded = serde_json::from_str::<Option<PlCredentialProvider>>(encoded.as_str());
        assert!(matches!(decoded, Ok(None)));
    }
}