Skip to main content

floe_core/io/storage/
adls.rs

1use std::path::{Path, PathBuf};
2use std::sync::Arc;
3
4use azure_identity::{DefaultAzureCredential, TokenCredentialOptions};
5use azure_storage::StorageCredentials;
6use azure_storage_blobs::prelude::{BlobServiceClient, ContainerClient};
7use futures::StreamExt;
8use tokio::runtime::Runtime;
9
10use crate::errors::{RunError, StorageError};
11use crate::{config, ConfigError, FloeResult};
12
13use super::{planner, ObjectRef, StorageClient};
14
15pub struct AdlsClient {
16    account: String,
17    container: String,
18    prefix: String,
19    runtime: Runtime,
20    container_client: ContainerClient,
21}
22
23impl AdlsClient {
24    pub fn new(definition: &config::StorageDefinition) -> FloeResult<Self> {
25        let account = definition.account.clone().ok_or_else(|| {
26            Box::new(StorageError(format!(
27                "storage {} requires account for type adls",
28                definition.name
29            )))
30        })?;
31        let container = definition.container.clone().ok_or_else(|| {
32            Box::new(StorageError(format!(
33                "storage {} requires container for type adls",
34                definition.name
35            )))
36        })?;
37        let prefix = definition.prefix.clone().unwrap_or_default();
38        let runtime = tokio::runtime::Builder::new_current_thread()
39            .enable_all()
40            .build()
41            .map_err(|err| Box::new(StorageError(format!("adls runtime init failed: {err}"))))?;
42        let credential = DefaultAzureCredential::create(TokenCredentialOptions::default())
43            .map_err(|err| Box::new(StorageError(format!("adls credential init failed: {err}"))))?;
44        let storage_credentials = StorageCredentials::token_credential(Arc::new(credential));
45        let service_client = BlobServiceClient::new(account.clone(), storage_credentials);
46        let container_client = service_client.container_client(container.clone());
47        Ok(Self {
48            account,
49            container,
50            prefix,
51            runtime,
52            container_client,
53        })
54    }
55
56    fn base_prefix(&self) -> String {
57        planner::normalize_separators(&self.prefix)
58    }
59
60    fn full_path(&self, path: &str) -> String {
61        let prefix = self.base_prefix();
62        let joined = planner::join_prefix(&prefix, &planner::normalize_separators(path));
63        joined.trim_start_matches('/').to_string()
64    }
65
66    fn format_abfs(&self, path: &str) -> String {
67        format_abfs_uri(&self.container, &self.account, path)
68    }
69}
70
71impl StorageClient for AdlsClient {
72    fn list(&self, prefix_or_path: &str) -> FloeResult<Vec<ObjectRef>> {
73        let prefix = self.full_path(prefix_or_path);
74        let container = self.container.clone();
75        let account = self.account.clone();
76        let client = self.container_client.clone();
77        self.runtime.block_on(async move {
78            let mut refs = Vec::new();
79            let mut stream = client.list_blobs().prefix(prefix.clone()).into_stream();
80            while let Some(resp) = stream.next().await {
81                let resp = resp.map_err(|err| {
82                    Box::new(StorageError(format!("adls list failed: {err}")))
83                        as Box<dyn std::error::Error + Send + Sync>
84                })?;
85                for blob in resp.blobs.blobs() {
86                    let key = blob.name.clone();
87                    let uri = if key.is_empty() {
88                        format!("abfs://{}@{}.dfs.core.windows.net", container, account)
89                    } else {
90                        format!(
91                            "abfs://{}@{}.dfs.core.windows.net/{}",
92                            container, account, key
93                        )
94                    };
95                    refs.push(ObjectRef {
96                        uri,
97                        key,
98                        last_modified: Some(blob.properties.last_modified.to_string()),
99                        size: Some(blob.properties.content_length),
100                    });
101                }
102            }
103            Ok(planner::stable_sort_refs(refs))
104        })
105    }
106
107    fn download_to_temp(&self, uri: &str, temp_dir: &Path) -> FloeResult<PathBuf> {
108        let key = uri
109            .split_once(".dfs.core.windows.net/")
110            .map(|(_, tail)| tail)
111            .unwrap_or("")
112            .trim_start_matches('/')
113            .to_string();
114        let key = if key.is_empty() {
115            return Err(Box::new(StorageError(
116                "adls download requires a blob path".to_string(),
117            )));
118        } else {
119            key
120        };
121        let dest = planner::temp_path_for_key(temp_dir, &key);
122        let dest_clone = dest.clone();
123        let client = self.container_client.clone();
124        let key_clone = key.clone();
125        self.runtime.block_on(async move {
126            if let Some(parent) = dest_clone.parent() {
127                tokio::fs::create_dir_all(parent).await?;
128            }
129            let blob = client.blob_client(key_clone);
130            let mut stream = blob.get().into_stream();
131            let mut file = tokio::fs::File::create(&dest_clone).await?;
132            while let Some(chunk) = stream.next().await {
133                let resp = chunk.map_err(|err| {
134                    Box::new(StorageError(format!("adls download failed: {err}")))
135                        as Box<dyn std::error::Error + Send + Sync>
136                })?;
137                let bytes = resp.data.collect().await.map_err(|err| {
138                    Box::new(StorageError(format!("adls download read failed: {err}")))
139                        as Box<dyn std::error::Error + Send + Sync>
140                })?;
141                tokio::io::AsyncWriteExt::write_all(&mut file, &bytes).await?;
142            }
143            Ok::<(), Box<dyn std::error::Error + Send + Sync>>(())
144        })?;
145        Ok(dest)
146    }
147
148    fn upload_from_path(&self, local_path: &Path, uri: &str) -> FloeResult<()> {
149        let key = uri
150            .split_once(".dfs.core.windows.net/")
151            .map(|(_, tail)| tail)
152            .unwrap_or("")
153            .trim_start_matches('/')
154            .to_string();
155        if key.is_empty() {
156            return Err(Box::new(StorageError(
157                "adls upload requires a blob path".to_string(),
158            )));
159        }
160        let client = self.container_client.clone();
161        let path = local_path.to_path_buf();
162        self.runtime.block_on(async move {
163            let data = tokio::fs::read(path).await?;
164            let blob = client.blob_client(key);
165            blob.put_block_blob(data)
166                .content_type("application/octet-stream")
167                .into_future()
168                .await
169                .map_err(|err| {
170                    Box::new(StorageError(format!("adls upload failed: {err}")))
171                        as Box<dyn std::error::Error + Send + Sync>
172                })?;
173            Ok(())
174        })
175    }
176
177    fn resolve_uri(&self, path: &str) -> FloeResult<String> {
178        Ok(self.format_abfs(&self.full_path(path)))
179    }
180
181    fn copy_object(&self, src_uri: &str, dst_uri: &str) -> FloeResult<()> {
182        let temp_dir = tempfile::TempDir::new().map_err(|err| {
183            Box::new(StorageError(format!("adls tempdir failed: {err}")))
184                as Box<dyn std::error::Error + Send + Sync>
185        })?;
186        let temp_path = self.download_to_temp(src_uri, temp_dir.path())?;
187        self.upload_from_path(&temp_path, dst_uri)?;
188        Ok(())
189    }
190
191    fn delete_object(&self, uri: &str) -> FloeResult<()> {
192        let key = uri
193            .split_once(".dfs.core.windows.net/")
194            .map(|(_, tail)| tail)
195            .unwrap_or("")
196            .trim_start_matches('/')
197            .to_string();
198        if key.is_empty() {
199            return Ok(());
200        }
201        let client = self.container_client.clone();
202        self.runtime.block_on(async move {
203            let blob = client.blob_client(key);
204            blob.delete().into_future().await.map_err(|err| {
205                Box::new(StorageError(format!("adls delete failed: {err}")))
206                    as Box<dyn std::error::Error + Send + Sync>
207            })?;
208            Ok(())
209        })
210    }
211
212    fn exists(&self, uri: &str) -> FloeResult<bool> {
213        let key = uri
214            .split_once(".dfs.core.windows.net/")
215            .map(|(_, tail)| tail)
216            .unwrap_or("")
217            .trim_start_matches('/')
218            .to_string();
219        if key.is_empty() {
220            return Ok(false);
221        }
222        let refs = self.list(&key)?;
223        Ok(refs.iter().any(|object| object.key == key))
224    }
225}
226
/// The three components of an `abfs://<container>@<account>.dfs.core.windows.net/<path>` URI.
#[derive(Clone, Debug, Eq, PartialEq)]
pub struct AdlsLocation {
    /// Storage account name (the host part before `.dfs.core.windows.net`).
    pub account: String,
    /// Container name (the part before `@`).
    pub container: String,
    /// Path inside the container, without a leading `/`; may be empty.
    pub path: String,
}
233
234pub fn parse_adls_uri(uri: &str) -> FloeResult<AdlsLocation> {
235    let stripped = uri.strip_prefix("abfs://").ok_or_else(|| {
236        Box::new(ConfigError(format!("expected abfs uri, got {}", uri)))
237            as Box<dyn std::error::Error + Send + Sync>
238    })?;
239    let (container, rest) = stripped.split_once('@').ok_or_else(|| {
240        Box::new(ConfigError(format!(
241            "missing container in abfs uri: {}",
242            uri
243        ))) as Box<dyn std::error::Error + Send + Sync>
244    })?;
245    let (account, path) = rest.split_once(".dfs.core.windows.net").ok_or_else(|| {
246        Box::new(ConfigError(format!("missing account in abfs uri: {}", uri)))
247            as Box<dyn std::error::Error + Send + Sync>
248    })?;
249    let path = path.trim_start_matches('/');
250    Ok(AdlsLocation {
251        account: account.to_string(),
252        container: container.to_string(),
253        path: path.to_string(),
254    })
255}
256
/// Builds an `abfs://<container>@<account>.dfs.core.windows.net[/<path>]` URI.
///
/// Leading slashes on `path` are dropped; when nothing remains, the URI ends
/// at the host with no trailing slash.
pub fn format_abfs_uri(container: &str, account: &str, path: &str) -> String {
    let mut uri = format!("abfs://{}@{}.dfs.core.windows.net", container, account);
    let relative = path.trim_start_matches('/');
    if !relative.is_empty() {
        uri.push('/');
        uri.push_str(relative);
    }
    uri
}
268
269pub fn build_input_files(
270    client: &dyn StorageClient,
271    container: &str,
272    account: &str,
273    prefix: &str,
274    adapter: &dyn crate::io::format::InputAdapter,
275    temp_dir: &Path,
276    entity: &crate::config::EntityConfig,
277    storage: &str,
278) -> FloeResult<Vec<crate::io::format::InputFile>> {
279    let suffixes = adapter.suffixes()?;
280    let list_refs = client.list(prefix)?;
281    let filtered = planner::filter_by_suffixes(list_refs, &suffixes);
282    let filtered = planner::stable_sort_refs(filtered);
283    if filtered.is_empty() {
284        return Err(Box::new(RunError(format!(
285            "entity.name={} source.storage={} no input objects matched (container={}, account={}, prefix={}, suffixes={})",
286            entity.name,
287            storage,
288            container,
289            account,
290            prefix,
291            suffixes.join(",")
292        ))));
293    }
294    let mut inputs = Vec::with_capacity(filtered.len());
295    for object in filtered {
296        let local_path = client.download_to_temp(&object.uri, temp_dir)?;
297        let source_name = crate::io::storage::s3::file_name_from_key(&object.key)
298            .unwrap_or_else(|| entity.name.clone());
299        let source_stem = crate::io::storage::s3::file_stem_from_name(&source_name)
300            .unwrap_or_else(|| entity.name.clone());
301        let source_uri = object.uri;
302        inputs.push(crate::io::format::InputFile {
303            source_uri,
304            source_local_path: local_path,
305            source_name,
306            source_stem,
307        });
308    }
309    Ok(inputs)
310}