Skip to main content

floe_core/io/write/strategy/
mod.rs

1use std::path::Path;
2
3use crate::io::storage::Target;
4use crate::{config, io, ConfigError, FloeResult};
5
6use super::parts;
7
8mod append;
9mod overwrite;
10
/// Identifies which sink a set of part files belongs to.
///
/// The variant selects whether configuration errors refer to
/// `sink.accepted.path` or `sink.rejected.path`. The `format` label is
/// interpolated into the accepted-sink error messages.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PartScope {
    /// Part files written to the accepted sink in the given format.
    Accepted { format: &'static str },
    /// Part files written to the rejected sink in the given format.
    Rejected { format: &'static str },
}

/// Describes the part files a write strategy manages.
///
/// `PartialEq`/`Eq` are derived so specs can be compared directly
/// (for example in `assert_eq!`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct PartSpec {
    /// File extension handed to the part-name allocator and to the helpers
    /// that recognise existing part files.
    pub extension: &'static str,
    /// Sink the part files belong to.
    pub scope: PartScope,
}
22
23pub struct WriteContext<'a> {
24    pub target: &'a Target,
25    pub cloud: &'a mut io::storage::CloudClient,
26    pub resolver: &'a config::StorageResolver,
27    pub entity: &'a config::EntityConfig,
28}
29
30pub trait ModeStrategy {
31    fn mode(&self) -> config::WriteMode;
32    fn part_allocator(
33        &self,
34        ctx: &mut WriteContext<'_>,
35        spec: PartSpec,
36    ) -> FloeResult<parts::PartNameAllocator>;
37}
38
/// Cloud backend of a target, carrying the identifiers quoted in
/// configuration error messages.
#[derive(Debug, Clone)]
enum CloudProvider {
    /// Amazon S3; no extra identifiers are recorded.
    S3,
    /// Google Cloud Storage, with the bucket name.
    Gcs { bucket: String },
    /// Azure Data Lake Storage, with the container and account names.
    Adls { container: String, account: String },
}
45
46pub fn strategy_for(mode: config::WriteMode) -> &'static dyn ModeStrategy {
47    match mode {
48        config::WriteMode::Overwrite => &overwrite::OVERWRITE_STRATEGY,
49        config::WriteMode::Append => &append::APPEND_STRATEGY,
50    }
51}
52
53pub fn ensure_mode_supported(mode: config::WriteMode) -> FloeResult<()> {
54    match mode {
55        config::WriteMode::Overwrite => Ok(()),
56        config::WriteMode::Append => Ok(()),
57    }
58}
59
60pub fn accepted_parquet_spec() -> PartSpec {
61    PartSpec {
62        extension: "parquet",
63        scope: PartScope::Accepted { format: "parquet" },
64    }
65}
66
67pub fn rejected_csv_spec() -> PartSpec {
68    PartSpec {
69        extension: "csv",
70        scope: PartScope::Rejected { format: "csv" },
71    }
72}
73
74pub fn append_part_allocator(
75    _ctx: &mut WriteContext<'_>,
76    spec: PartSpec,
77) -> FloeResult<parts::PartNameAllocator> {
78    Ok(parts::PartNameAllocator::unique(spec.extension))
79}
80
81pub fn overwrite_part_allocator(
82    ctx: &mut WriteContext<'_>,
83    spec: PartSpec,
84) -> FloeResult<parts::PartNameAllocator> {
85    match ctx.target {
86        Target::Local { base_path, .. } => {
87            let base_path = Path::new(base_path);
88            let _ = parts::clear_local_part_files(base_path, spec.extension)?;
89            Ok(parts::PartNameAllocator::from_next_index(0, spec.extension))
90        }
91        Target::S3 { .. } | Target::Gcs { .. } | Target::Adls { .. } => {
92            clear_cloud_parts(ctx, spec)?;
93            Ok(parts::PartNameAllocator::from_next_index(0, spec.extension))
94        }
95    }
96}
97
98fn clear_cloud_parts(ctx: &mut WriteContext<'_>, spec: PartSpec) -> FloeResult<()> {
99    let (list_prefix, objects) = list_part_objects(ctx, spec)?;
100    let client = ctx
101        .cloud
102        .client_for(ctx.resolver, ctx.target.storage(), ctx.entity)?;
103    for object in objects
104        .into_iter()
105        .filter(|obj| obj.key.starts_with(&list_prefix))
106        .filter(|obj| parts::is_part_key(&obj.key, spec.extension))
107    {
108        client.delete_object(&object.uri)?;
109    }
110    Ok(())
111}
112
113pub(crate) fn list_part_objects(
114    ctx: &mut WriteContext<'_>,
115    spec: PartSpec,
116) -> FloeResult<(String, Vec<io::storage::ObjectRef>)> {
117    match ctx.target {
118        Target::S3 {
119            storage, base_key, ..
120        } => {
121            let provider = CloudProvider::S3;
122            let list_prefix = list_prefix(ctx.entity, base_key, &provider, spec)?;
123            let client = ctx.cloud.client_for(ctx.resolver, storage, ctx.entity)?;
124            let objects = client.list(&list_prefix)?;
125            Ok((list_prefix, objects))
126        }
127        Target::Gcs {
128            storage,
129            bucket,
130            base_key,
131            ..
132        } => {
133            let provider = CloudProvider::Gcs {
134                bucket: bucket.clone(),
135            };
136            let list_prefix = list_prefix(ctx.entity, base_key, &provider, spec)?;
137            let client = ctx.cloud.client_for(ctx.resolver, storage, ctx.entity)?;
138            let objects = client.list(&list_prefix)?;
139            Ok((list_prefix, objects))
140        }
141        Target::Adls {
142            storage,
143            container,
144            account,
145            base_path,
146            ..
147        } => {
148            let provider = CloudProvider::Adls {
149                container: container.clone(),
150                account: account.clone(),
151            };
152            let list_prefix = list_prefix(ctx.entity, base_path, &provider, spec)?;
153            let client = ctx.cloud.client_for(ctx.resolver, storage, ctx.entity)?;
154            let objects = client.list(&list_prefix)?;
155            Ok((list_prefix, objects))
156        }
157        Target::Local { .. } => Err(Box::new(ConfigError(
158            "cloud part listing requested for local target".to_string(),
159        ))),
160    }
161}
162
163fn list_prefix(
164    entity: &config::EntityConfig,
165    base_path: &str,
166    provider: &CloudProvider,
167    spec: PartSpec,
168) -> FloeResult<String> {
169    let prefix = base_path.trim_matches('/');
170    if prefix.is_empty() {
171        return Err(Box::new(prefix_error(entity, provider, spec)));
172    }
173    Ok(format!("{prefix}/"))
174}
175
176fn prefix_error(
177    entity: &config::EntityConfig,
178    provider: &CloudProvider,
179    spec: PartSpec,
180) -> ConfigError {
181    match (&spec.scope, provider) {
182        (PartScope::Accepted { format }, CloudProvider::S3) => ConfigError(format!(
183            "entity.name={} sink.accepted.path must not be bucket root for s3 {format} outputs",
184            entity.name
185        )),
186        (PartScope::Accepted { format }, CloudProvider::Gcs { bucket }) => ConfigError(format!(
187            "entity.name={} sink.accepted.path must not be bucket root for gcs {format} outputs (bucket={})",
188            entity.name, bucket
189        )),
190        (PartScope::Accepted { format }, CloudProvider::Adls { container, account }) => {
191            ConfigError(format!(
192                "entity.name={} sink.accepted.path must not be container root for adls {format} outputs (container={}, account={})",
193                entity.name, container, account
194            ))
195        }
196        (PartScope::Rejected { .. }, CloudProvider::S3) => ConfigError(format!(
197            "entity.name={} sink.rejected.path must not be bucket root for s3 outputs",
198            entity.name
199        )),
200        (PartScope::Rejected { .. }, CloudProvider::Gcs { .. }) => ConfigError(format!(
201            "entity.name={} sink.rejected.path must not be bucket root for gcs outputs",
202            entity.name
203        )),
204        (PartScope::Rejected { .. }, CloudProvider::Adls { .. }) => ConfigError(format!(
205            "entity.name={} sink.rejected.path must not be container root for adls outputs",
206            entity.name
207        )),
208    }
209}