Skip to main content

oxigdal_vrt/
band.rs

1//! VRT virtual band configuration
2
3use crate::error::{Result, VrtError};
4use crate::source::VrtSource;
5use oxigdal_core::types::{ColorInterpretation, NoDataValue, RasterDataType};
6use serde::{Deserialize, Serialize};
7
/// VRT band configuration
///
/// Describes one band of a virtual raster: the inputs that feed it
/// (`sources` and/or a `pixel_function`) and how its values are interpreted.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct VrtBand {
    /// Band number (1-based; `validate` rejects 0)
    pub band: usize,
    /// Data type
    pub data_type: RasterDataType,
    /// Color interpretation (the constructors default it to `Undefined`)
    pub color_interp: ColorInterpretation,
    /// NoData value
    pub nodata: NoDataValue,
    /// Sources for this band
    pub sources: Vec<VrtSource>,
    /// Block size as `(width, height)` (tile dimensions)
    pub block_size: Option<(u32, u32)>,
    /// Pixel function (for on-the-fly computation)
    pub pixel_function: Option<PixelFunction>,
    /// Offset for scaling; added after `scale` has been applied
    pub offset: Option<f64>,
    /// Scale factor; multiplied into the value before `offset` is added
    pub scale: Option<f64>,
    /// Color table
    pub color_table: Option<ColorTable>,
}
32
33impl VrtBand {
34    /// Creates a new VRT band
35    pub fn new(band: usize, data_type: RasterDataType) -> Self {
36        Self {
37            band,
38            data_type,
39            color_interp: ColorInterpretation::Undefined,
40            nodata: NoDataValue::None,
41            sources: Vec::new(),
42            block_size: None,
43            pixel_function: None,
44            offset: None,
45            scale: None,
46            color_table: None,
47        }
48    }
49
50    /// Creates a simple band with a single source
51    pub fn simple(band: usize, data_type: RasterDataType, source: VrtSource) -> Self {
52        Self {
53            band,
54            data_type,
55            color_interp: ColorInterpretation::Undefined,
56            nodata: source.nodata.unwrap_or(NoDataValue::None),
57            sources: vec![source],
58            block_size: None,
59            pixel_function: None,
60            offset: None,
61            scale: None,
62            color_table: None,
63        }
64    }
65
66    /// Adds a source to this band
67    pub fn add_source(&mut self, source: VrtSource) {
68        self.sources.push(source);
69    }
70
71    /// Sets the color interpretation
72    pub fn with_color_interp(mut self, color_interp: ColorInterpretation) -> Self {
73        self.color_interp = color_interp;
74        self
75    }
76
77    /// Sets the NoData value
78    pub fn with_nodata(mut self, nodata: NoDataValue) -> Self {
79        self.nodata = nodata;
80        self
81    }
82
83    /// Sets the block size
84    pub fn with_block_size(mut self, width: u32, height: u32) -> Self {
85        self.block_size = Some((width, height));
86        self
87    }
88
89    /// Sets the pixel function
90    pub fn with_pixel_function(mut self, function: PixelFunction) -> Self {
91        self.pixel_function = Some(function);
92        self
93    }
94
95    /// Sets the offset and scale
96    pub fn with_scaling(mut self, offset: f64, scale: f64) -> Self {
97        self.offset = Some(offset);
98        self.scale = Some(scale);
99        self
100    }
101
102    /// Sets the color table
103    pub fn with_color_table(mut self, color_table: ColorTable) -> Self {
104        self.color_table = Some(color_table);
105        self
106    }
107
108    /// Validates the band configuration
109    ///
110    /// # Errors
111    /// Returns an error if the band is invalid
112    pub fn validate(&self) -> Result<()> {
113        if self.band == 0 {
114            return Err(VrtError::invalid_band("Band number must be >= 1"));
115        }
116
117        if self.sources.is_empty() && self.pixel_function.is_none() {
118            return Err(VrtError::invalid_band(
119                "Band must have at least one source or a pixel function",
120            ));
121        }
122
123        // Validate all sources
124        for source in &self.sources {
125            source.validate()?;
126        }
127
128        // Validate pixel function if present
129        if let Some(ref func) = self.pixel_function {
130            func.validate(&self.sources)?;
131        }
132
133        Ok(())
134    }
135
136    /// Checks if this band has multiple sources
137    pub fn has_multiple_sources(&self) -> bool {
138        self.sources.len() > 1
139    }
140
141    /// Checks if this band uses a pixel function
142    pub fn uses_pixel_function(&self) -> bool {
143        self.pixel_function.is_some()
144    }
145
146    /// Applies scaling to a value
147    pub fn apply_scaling(&self, value: f64) -> f64 {
148        let scaled = if let Some(scale) = self.scale {
149            value * scale
150        } else {
151            value
152        };
153
154        if let Some(offset) = self.offset {
155            scaled + offset
156        } else {
157            scaled
158        }
159    }
160}
161
/// Pixel function for on-the-fly computation
///
/// Combines the per-source values of a band into a single output value (see
/// `PixelFunction::apply`). `None` inputs denote NoData.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum PixelFunction {
    /// Arithmetic mean of all valid (non-NoData) source values;
    /// NoData if every source is NoData
    Average,
    /// Minimum of all valid (non-NoData) source values
    Min,
    /// Maximum of all valid (non-NoData) source values
    Max,
    /// Sum of all valid (non-NoData) source values; NoData if every source is NoData
    Sum,
    /// First valid (non-NoData) value, in source order
    FirstValid,
    /// Last valid (non-NoData) value, in source order
    LastValid,
    /// Weighted average (requires weights)
    ///
    /// Validation requires exactly one weight per source, with the weights
    /// summing to 1.0 (within 0.001). NoData sources are skipped and the
    /// remaining weights are renormalized.
    WeightedAverage {
        /// Weights for each source, in source order
        weights: Vec<f64>,
    },
    /// NDVI: (NIR - Red) / (NIR + Red)
    /// Requires exactly 2 sources: [Red, NIR]; NoData when the denominator is ~0
    Ndvi,
    /// Enhanced Vegetation Index: 2.5 * (NIR - Red) / (NIR + 6*Red - 7.5*Blue + 1)
    /// Requires exactly 3 sources: [Red, NIR, Blue]; NoData when the denominator is ~0
    Evi,
    /// Normalized Difference Water Index: (Green - NIR) / (Green + NIR)
    /// Requires exactly 2 sources: [Green, NIR]; NoData when the denominator is ~0
    Ndwi,
    /// Band math expression
    ///
    /// Supports the binary operators `+`, `-`, `*`, `/` and parentheses, plus
    /// function calls such as `sqrt(...)` and `abs(...)`.
    /// Variables are named `B1`, `B2`, `B3`, ... and refer to the sources in
    /// order (1-based). If any source value is NoData, the result is NoData.
    BandMath {
        /// Expression string (e.g., "(B1 + B2) / 2", "sqrt(B1 * B2)")
        expression: String,
    },
    /// Lookup table transformation applied to the first source's value
    /// Maps input values to output values
    LookupTable {
        /// Lookup table: vec of (input_value, output_value) pairs.
        /// Linear interpolation assumes the pairs are sorted by input value and
        /// clamps inputs outside the table range to the first/last output.
        table: Vec<(f64, f64)>,
        /// Interpolation method: "nearest", "linear" (anything else is an error)
        interpolation: String,
    },
    /// Conditional logic: if condition then value_if_true else value_if_false
    /// Condition format: "B1 > 0.5", "B1 >= B2", etc. (operators `>`, `<`,
    /// `>=`, `<=`, `==`, `!=`). Applying it returns an error if any source is NoData.
    Conditional {
        /// Condition expression
        condition: String,
        /// Value if condition is true (evaluated as a band math expression)
        value_if_true: String,
        /// Value if condition is false (evaluated as a band math expression)
        value_if_false: String,
    },
    /// Product of all valid (non-NoData) source values
    Multiply,
    /// Divide first source by second (handles division by zero)
    /// A divisor of ~0 yields NoData
    Divide,
    /// Square root of the first source value; NoData for negative inputs
    SquareRoot,
    /// Absolute value of the first source value
    Absolute,
    /// Custom function (not yet implemented; rejected by validation and application)
    Custom {
        /// Function name
        name: String,
    },
}
230
231impl PixelFunction {
232    /// Validates the pixel function against sources
233    ///
234    /// # Errors
235    /// Returns an error if the function is invalid for the given sources
236    pub fn validate(&self, sources: &[VrtSource]) -> Result<()> {
237        match self {
238            Self::WeightedAverage { weights } => {
239                if weights.len() != sources.len() {
240                    return Err(VrtError::invalid_band(format!(
241                        "WeightedAverage requires {} weights, got {}",
242                        sources.len(),
243                        weights.len()
244                    )));
245                }
246
247                // Check that weights sum to approximately 1.0
248                let sum: f64 = weights.iter().sum();
249                if (sum - 1.0).abs() > 0.001 {
250                    return Err(VrtError::invalid_band(format!(
251                        "Weights should sum to 1.0, got {}",
252                        sum
253                    )));
254                }
255            }
256            Self::Ndvi | Self::Ndwi if sources.len() != 2 => {
257                return Err(VrtError::invalid_band(format!(
258                    "{:?} requires exactly 2 sources, got {}",
259                    self,
260                    sources.len()
261                )));
262            }
263            Self::Evi if sources.len() != 3 => {
264                return Err(VrtError::invalid_band(format!(
265                    "EVI requires exactly 3 sources, got {}",
266                    sources.len()
267                )));
268            }
269            Self::BandMath { expression } if expression.trim().is_empty() => {
270                return Err(VrtError::invalid_band(
271                    "BandMath expression cannot be empty",
272                ));
273            }
274            Self::LookupTable { table, .. } if table.is_empty() => {
275                return Err(VrtError::invalid_band("LookupTable cannot be empty"));
276            }
277            Self::Conditional {
278                condition,
279                value_if_true,
280                value_if_false,
281            } => {
282                if condition.trim().is_empty() {
283                    return Err(VrtError::invalid_band(
284                        "Conditional condition cannot be empty",
285                    ));
286                }
287                if value_if_true.trim().is_empty() || value_if_false.trim().is_empty() {
288                    return Err(VrtError::invalid_band("Conditional values cannot be empty"));
289                }
290            }
291            Self::Divide | Self::Multiply if sources.len() < 2 => {
292                return Err(VrtError::invalid_band(format!(
293                    "{:?} requires at least 2 sources, got {}",
294                    self,
295                    sources.len()
296                )));
297            }
298            Self::SquareRoot | Self::Absolute if sources.is_empty() => {
299                return Err(VrtError::invalid_band(format!(
300                    "{:?} requires at least 1 source",
301                    self
302                )));
303            }
304            Self::Custom { name } => {
305                return Err(VrtError::InvalidPixelFunction {
306                    function: name.clone(),
307                });
308            }
309            _ => {}
310        }
311        Ok(())
312    }
313
314    /// Applies the pixel function to a set of values
315    ///
316    /// # Errors
317    /// Returns an error if the function cannot be applied
318    pub fn apply(&self, values: &[Option<f64>]) -> Result<Option<f64>> {
319        match self {
320            Self::Average => {
321                let valid: Vec<f64> = values.iter().filter_map(|v| *v).collect();
322                if valid.is_empty() {
323                    Ok(None)
324                } else {
325                    Ok(Some(valid.iter().sum::<f64>() / valid.len() as f64))
326                }
327            }
328            Self::Min => Ok(values
329                .iter()
330                .filter_map(|v| *v)
331                .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))),
332            Self::Max => Ok(values
333                .iter()
334                .filter_map(|v| *v)
335                .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))),
336            Self::Sum => {
337                let valid: Vec<f64> = values.iter().filter_map(|v| *v).collect();
338                if valid.is_empty() {
339                    Ok(None)
340                } else {
341                    Ok(Some(valid.iter().sum()))
342                }
343            }
344            Self::FirstValid => Ok(values.iter().find_map(|v| *v)),
345            Self::LastValid => Ok(values.iter().rev().find_map(|v| *v)),
346            Self::WeightedAverage { weights } => {
347                if weights.len() != values.len() {
348                    return Err(VrtError::invalid_band("Weight count mismatch"));
349                }
350
351                let mut sum = 0.0;
352                let mut weight_sum = 0.0;
353
354                for (value, weight) in values.iter().zip(weights.iter()) {
355                    if let Some(v) = value {
356                        sum += v * weight;
357                        weight_sum += weight;
358                    }
359                }
360
361                if weight_sum > 0.0 {
362                    Ok(Some(sum / weight_sum))
363                } else {
364                    Ok(None)
365                }
366            }
367            Self::Ndvi => {
368                // NDVI = (NIR - Red) / (NIR + Red)
369                if values.len() != 2 {
370                    return Err(VrtError::invalid_band("NDVI requires exactly 2 values"));
371                }
372                match (values[0], values[1]) {
373                    (Some(red), Some(nir)) => {
374                        let denominator = nir + red;
375                        if denominator.abs() < f64::EPSILON {
376                            Ok(None) // Avoid division by zero
377                        } else {
378                            Ok(Some((nir - red) / denominator))
379                        }
380                    }
381                    _ => Ok(None),
382                }
383            }
384            Self::Evi => {
385                // EVI = 2.5 * (NIR - Red) / (NIR + 6*Red - 7.5*Blue + 1)
386                if values.len() != 3 {
387                    return Err(VrtError::invalid_band("EVI requires exactly 3 values"));
388                }
389                match (values[0], values[1], values[2]) {
390                    (Some(red), Some(nir), Some(blue)) => {
391                        let denominator = nir + 6.0 * red - 7.5 * blue + 1.0;
392                        if denominator.abs() < f64::EPSILON {
393                            Ok(None)
394                        } else {
395                            Ok(Some(2.5 * (nir - red) / denominator))
396                        }
397                    }
398                    _ => Ok(None),
399                }
400            }
401            Self::Ndwi => {
402                // NDWI = (Green - NIR) / (Green + NIR)
403                if values.len() != 2 {
404                    return Err(VrtError::invalid_band("NDWI requires exactly 2 values"));
405                }
406                match (values[0], values[1]) {
407                    (Some(green), Some(nir)) => {
408                        let denominator = green + nir;
409                        if denominator.abs() < f64::EPSILON {
410                            Ok(None)
411                        } else {
412                            Ok(Some((green - nir) / denominator))
413                        }
414                    }
415                    _ => Ok(None),
416                }
417            }
418            Self::BandMath { expression } => Self::evaluate_expression(expression, values),
419            Self::LookupTable {
420                table,
421                interpolation,
422            } => {
423                if values.is_empty() {
424                    return Ok(None);
425                }
426                if let Some(value) = values[0] {
427                    Self::apply_lookup_table(value, table, interpolation)
428                } else {
429                    Ok(None)
430                }
431            }
432            Self::Conditional {
433                condition,
434                value_if_true,
435                value_if_false,
436            } => Self::evaluate_conditional(condition, value_if_true, value_if_false, values),
437            Self::Multiply => {
438                let valid: Vec<f64> = values.iter().filter_map(|v| *v).collect();
439                if valid.is_empty() {
440                    Ok(None)
441                } else {
442                    Ok(Some(valid.iter().product()))
443                }
444            }
445            Self::Divide => {
446                if values.len() < 2 {
447                    return Err(VrtError::invalid_band("Divide requires at least 2 values"));
448                }
449                match (values[0], values[1]) {
450                    (Some(numerator), Some(denominator)) => {
451                        if denominator.abs() < f64::EPSILON {
452                            Ok(None) // Avoid division by zero
453                        } else {
454                            Ok(Some(numerator / denominator))
455                        }
456                    }
457                    _ => Ok(None),
458                }
459            }
460            Self::SquareRoot => {
461                if values.is_empty() {
462                    return Ok(None);
463                }
464                values[0].map_or(Ok(None), |v| {
465                    if v < 0.0 {
466                        Ok(None) // Negative values have no real square root
467                    } else {
468                        Ok(Some(v.sqrt()))
469                    }
470                })
471            }
472            Self::Absolute => {
473                if values.is_empty() {
474                    return Ok(None);
475                }
476                Ok(values[0].map(|v| v.abs()))
477            }
478            Self::Custom { name } => Err(VrtError::InvalidPixelFunction {
479                function: name.clone(),
480            }),
481        }
482    }
483
484    /// Evaluates a band math expression
485    fn evaluate_expression(expression: &str, values: &[Option<f64>]) -> Result<Option<f64>> {
486        // Simple expression evaluator for basic band math
487        // Supports: +, -, *, /, sqrt, pow, abs, min, max
488        // Variables: B1, B2, B3, etc.
489
490        let mut expr = expression.to_string();
491
492        // Replace band variables with actual values
493        for (i, value) in values.iter().enumerate() {
494            let var = format!("B{}", i + 1);
495            if let Some(v) = value {
496                expr = expr.replace(&var, &v.to_string());
497            } else {
498                return Ok(None); // If any band is NoData, result is NoData
499            }
500        }
501
502        // Basic expression evaluation (simplified)
503        // For production, consider using a proper expression parser like `evalexpr`
504        match Self::simple_eval(&expr) {
505            Ok(result) => Ok(Some(result)),
506            Err(_) => Err(VrtError::invalid_band(format!(
507                "Failed to evaluate expression: {}",
508                expression
509            ))),
510        }
511    }
512
513    /// Simple expression evaluator (basic implementation)
514    fn simple_eval(expr: &str) -> Result<f64> {
515        let expr = expr.trim();
516
517        // Try to parse as number first
518        if let Ok(num) = expr.parse::<f64>() {
519            return Ok(num);
520        }
521
522        // Handle sqrt
523        if expr.starts_with("sqrt(") && expr.ends_with(')') {
524            let inner = &expr[5..expr.len() - 1];
525            let val = Self::simple_eval(inner)?;
526            if val < 0.0 {
527                return Err(VrtError::invalid_band("Square root of negative number"));
528            }
529            return Ok(val.sqrt());
530        }
531
532        // Handle abs
533        if expr.starts_with("abs(") && expr.ends_with(')') {
534            let inner = &expr[4..expr.len() - 1];
535            let val = Self::simple_eval(inner)?;
536            return Ok(val.abs());
537        }
538
539        // Handle parentheses
540        if expr.starts_with('(') && expr.ends_with(')') {
541            // Check if these are balanced outer parens
542            let inner = &expr[1..expr.len() - 1];
543            let mut depth = 0;
544            let mut is_outer = true;
545            for ch in inner.chars() {
546                if ch == '(' {
547                    depth += 1;
548                } else if ch == ')' {
549                    depth -= 1;
550                    if depth < 0 {
551                        is_outer = false;
552                        break;
553                    }
554                }
555            }
556            if is_outer && depth == 0 {
557                return Self::simple_eval(inner);
558            }
559        }
560
561        // Handle binary operations (search from right to left for left-to-right evaluation)
562        // Process + and - first (lower precedence), then * and / (higher precedence)
563        for op in &['+', '-'] {
564            let mut depth = 0;
565            for (i, ch) in expr.char_indices().rev() {
566                if ch == ')' {
567                    depth += 1;
568                } else if ch == '(' {
569                    depth -= 1;
570                } else if depth == 0 && ch == *op && i > 0 && i < expr.len() - 1 {
571                    let left = Self::simple_eval(&expr[..i])?;
572                    let right = Self::simple_eval(&expr[i + 1..])?;
573                    return match op {
574                        '+' => Ok(left + right),
575                        '-' => Ok(left - right),
576                        _ => unreachable!(),
577                    };
578                }
579            }
580        }
581
582        for op in &['*', '/'] {
583            let mut depth = 0;
584            for (i, ch) in expr.char_indices().rev() {
585                if ch == ')' {
586                    depth += 1;
587                } else if ch == '(' {
588                    depth -= 1;
589                } else if depth == 0 && ch == *op && i > 0 && i < expr.len() - 1 {
590                    let left = Self::simple_eval(&expr[..i])?;
591                    let right = Self::simple_eval(&expr[i + 1..])?;
592                    return match op {
593                        '*' => Ok(left * right),
594                        '/' => {
595                            if right.abs() < f64::EPSILON {
596                                Err(VrtError::invalid_band("Division by zero"))
597                            } else {
598                                Ok(left / right)
599                            }
600                        }
601                        _ => unreachable!(),
602                    };
603                }
604            }
605        }
606
607        Err(VrtError::invalid_band(format!(
608            "Cannot parse expression: {}",
609            expr
610        )))
611    }
612
613    /// Applies lookup table transformation
614    fn apply_lookup_table(
615        value: f64,
616        table: &[(f64, f64)],
617        interpolation: &str,
618    ) -> Result<Option<f64>> {
619        if table.is_empty() {
620            return Ok(None);
621        }
622
623        match interpolation {
624            "nearest" => {
625                // Find nearest entry
626                let mut best_idx = 0;
627                let mut best_dist = (table[0].0 - value).abs();
628
629                for (i, (input, _)) in table.iter().enumerate() {
630                    let dist = (input - value).abs();
631                    if dist < best_dist {
632                        best_dist = dist;
633                        best_idx = i;
634                    }
635                }
636
637                Ok(Some(table[best_idx].1))
638            }
639            "linear" => {
640                // Linear interpolation
641                if value <= table[0].0 {
642                    return Ok(Some(table[0].1));
643                }
644                if value >= table[table.len() - 1].0 {
645                    return Ok(Some(table[table.len() - 1].1));
646                }
647
648                // Find surrounding points
649                for i in 0..table.len() - 1 {
650                    if value >= table[i].0 && value <= table[i + 1].0 {
651                        let x0 = table[i].0;
652                        let y0 = table[i].1;
653                        let x1 = table[i + 1].0;
654                        let y1 = table[i + 1].1;
655
656                        let t = (value - x0) / (x1 - x0);
657                        return Ok(Some(y0 + t * (y1 - y0)));
658                    }
659                }
660
661                Ok(Some(table[0].1))
662            }
663            _ => Err(VrtError::invalid_band(format!(
664                "Unknown interpolation method: {}",
665                interpolation
666            ))),
667        }
668    }
669
670    /// Evaluates conditional expression
671    fn evaluate_conditional(
672        condition: &str,
673        value_if_true: &str,
674        value_if_false: &str,
675        values: &[Option<f64>],
676    ) -> Result<Option<f64>> {
677        // Simple condition evaluator
678        // Supports: >, <, >=, <=, ==, !=
679
680        let cond_result = Self::evaluate_condition(condition, values)?;
681
682        let target_expr = if cond_result {
683            value_if_true
684        } else {
685            value_if_false
686        };
687
688        // Evaluate the target expression
689        Self::evaluate_expression(target_expr, values)
690    }
691
692    /// Evaluates a boolean condition
693    fn evaluate_condition(condition: &str, values: &[Option<f64>]) -> Result<bool> {
694        let condition = condition.trim();
695
696        // Try different comparison operators in order (longer ones first to avoid false matches)
697        let operators = [">=", "<=", "==", "!=", ">", "<"];
698
699        for op_str in &operators {
700            if let Some(pos) = condition.find(op_str) {
701                let left_expr = condition[..pos].trim();
702                let right_expr = condition[pos + op_str.len()..].trim();
703
704                let left_val = Self::evaluate_expression(left_expr, values)?
705                    .ok_or_else(|| VrtError::invalid_band("Left side of condition is NoData"))?;
706
707                let right_val = Self::evaluate_expression(right_expr, values)?
708                    .ok_or_else(|| VrtError::invalid_band("Right side of condition is NoData"))?;
709
710                let result = match *op_str {
711                    ">=" => left_val >= right_val,
712                    "<=" => left_val <= right_val,
713                    ">" => left_val > right_val,
714                    "<" => left_val < right_val,
715                    "==" => (left_val - right_val).abs() < f64::EPSILON,
716                    "!=" => (left_val - right_val).abs() >= f64::EPSILON,
717                    _ => {
718                        return Err(VrtError::invalid_band(format!(
719                            "Unknown operator: {}",
720                            op_str
721                        )));
722                    }
723                };
724
725                return Ok(result);
726            }
727        }
728
729        Err(VrtError::invalid_band(format!(
730            "Cannot parse condition: {}",
731            condition
732        )))
733    }
734}
735
/// Color table entry: an index value together with its RGBA color
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub struct ColorEntry {
    /// Color value (index); the key used by `ColorTable::get`
    pub value: u16,
    /// Red component (0-255)
    pub r: u8,
    /// Green component (0-255)
    pub g: u8,
    /// Blue component (0-255)
    pub b: u8,
    /// Alpha component (0-255; `ColorEntry::rgb` sets it to 255, fully opaque)
    pub a: u8,
}
750
751impl ColorEntry {
752    /// Creates a new color entry
753    pub const fn new(value: u16, r: u8, g: u8, b: u8, a: u8) -> Self {
754        Self { value, r, g, b, a }
755    }
756
757    /// Creates an opaque color entry
758    pub const fn rgb(value: u16, r: u8, g: u8, b: u8) -> Self {
759        Self::new(value, r, g, b, 255)
760    }
761}
762
/// Color table (palette)
///
/// Entries are stored in insertion order and looked up by a linear search on
/// `ColorEntry::value`; if several entries share a value, the first one wins.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ColorTable {
    /// Color entries, in insertion order
    pub entries: Vec<ColorEntry>,
}
769
770impl ColorTable {
771    /// Creates a new empty color table
772    pub fn new() -> Self {
773        Self {
774            entries: Vec::new(),
775        }
776    }
777
778    /// Creates a color table with entries
779    pub fn with_entries(entries: Vec<ColorEntry>) -> Self {
780        Self { entries }
781    }
782
783    /// Adds a color entry
784    pub fn add_entry(&mut self, entry: ColorEntry) {
785        self.entries.push(entry);
786    }
787
788    /// Gets a color entry by value
789    pub fn get(&self, value: u16) -> Option<&ColorEntry> {
790        self.entries.iter().find(|e| e.value == value)
791    }
792}
793
794impl Default for ColorTable {
795    fn default() -> Self {
796        Self::new()
797    }
798}
799
800#[cfg(test)]
801mod tests {
802    use super::*;
803    use crate::source::SourceFilename;
804
805    #[test]
806    fn test_vrt_band_creation() {
807        let band = VrtBand::new(1, RasterDataType::UInt8);
808        assert_eq!(band.band, 1);
809        assert_eq!(band.data_type, RasterDataType::UInt8);
810    }
811
812    #[test]
813    fn test_vrt_band_validation() {
814        let source = VrtSource::new(SourceFilename::absolute("/test.tif"), 1);
815        let band = VrtBand::simple(1, RasterDataType::UInt8, source);
816        assert!(band.validate().is_ok());
817
818        let invalid_band = VrtBand::new(0, RasterDataType::UInt8);
819        assert!(invalid_band.validate().is_err());
820
821        let no_source_band = VrtBand::new(1, RasterDataType::UInt8);
822        assert!(no_source_band.validate().is_err());
823    }
824
825    #[test]
826    fn test_pixel_function_average() {
827        let func = PixelFunction::Average;
828        let values = vec![Some(1.0), Some(2.0), Some(3.0)];
829        let result = func.apply(&values);
830        assert!(result.is_ok());
831        assert_eq!(result.ok().flatten(), Some(2.0));
832
833        let with_none = vec![Some(1.0), None, Some(3.0)];
834        let result = func.apply(&with_none);
835        assert!(result.is_ok());
836        assert_eq!(result.ok().flatten(), Some(2.0));
837    }
838
839    #[test]
840    fn test_pixel_function_first_valid() {
841        let func = PixelFunction::FirstValid;
842        let values = vec![None, Some(2.0), Some(3.0)];
843        let result = func.apply(&values);
844        assert!(result.is_ok());
845        assert_eq!(result.ok().flatten(), Some(2.0));
846    }
847
848    #[test]
849    fn test_pixel_function_weighted_average() {
850        let func = PixelFunction::WeightedAverage {
851            weights: vec![0.5, 0.3, 0.2],
852        };
853        let values = vec![Some(10.0), Some(20.0), Some(30.0)];
854        let result = func.apply(&values);
855        assert!(result.is_ok());
856        // 10*0.5 + 20*0.3 + 30*0.2 = 5 + 6 + 6 = 17
857        assert_eq!(result.ok().flatten(), Some(17.0));
858    }
859
860    #[test]
861    fn test_band_scaling() {
862        let band = VrtBand::new(1, RasterDataType::UInt8).with_scaling(10.0, 2.0);
863
864        assert_eq!(band.apply_scaling(5.0), 20.0); // 5 * 2 + 10 = 20
865    }
866
867    #[test]
868    fn test_color_table() {
869        let mut table = ColorTable::new();
870        table.add_entry(ColorEntry::rgb(0, 255, 0, 0));
871        table.add_entry(ColorEntry::rgb(1, 0, 255, 0));
872        table.add_entry(ColorEntry::rgb(2, 0, 0, 255));
873
874        assert_eq!(table.entries.len(), 3);
875        assert_eq!(table.get(1).map(|e| e.g), Some(255));
876    }
877
878    #[test]
879    fn test_pixel_function_ndvi() {
880        let func = PixelFunction::Ndvi;
881
882        // Standard NDVI calculation
883        let values = vec![Some(0.1), Some(0.5)]; // Red, NIR
884        let result = func.apply(&values);
885        assert!(result.is_ok());
886        // NDVI = (0.5 - 0.1) / (0.5 + 0.1) = 0.4 / 0.6 = 0.666...
887        assert!((result.ok().flatten().expect("Should have value") - 0.666666).abs() < 0.001);
888
889        // With NoData
890        let values_nodata = vec![Some(0.1), None];
891        let result = func.apply(&values_nodata);
892        assert!(result.is_ok());
893        assert_eq!(result.ok().flatten(), None);
894
895        // Zero sum (edge case)
896        let values_zero = vec![Some(0.5), Some(-0.5)];
897        let result = func.apply(&values_zero);
898        assert!(result.is_ok());
899        assert_eq!(result.ok().flatten(), None); // Should return None to avoid division by zero
900    }
901
902    #[test]
903    fn test_pixel_function_evi() {
904        let func = PixelFunction::Evi;
905
906        // Standard EVI calculation
907        let values = vec![Some(0.1), Some(0.5), Some(0.05)]; // Red, NIR, Blue
908        let result = func.apply(&values);
909        assert!(result.is_ok());
910        // EVI = 2.5 * (0.5 - 0.1) / (0.5 + 6*0.1 - 7.5*0.05 + 1)
911        // = 2.5 * 0.4 / (0.5 + 0.6 - 0.375 + 1)
912        // = 1.0 / 1.725 = 0.5797...
913        let expected = 1.0 / 1.725;
914        assert!((result.ok().flatten().expect("Should have value") - expected).abs() < 0.001);
915    }
916
917    #[test]
918    fn test_pixel_function_ndwi() {
919        let func = PixelFunction::Ndwi;
920
921        // Standard NDWI calculation
922        let values = vec![Some(0.3), Some(0.2)]; // Green, NIR
923        let result = func.apply(&values);
924        assert!(result.is_ok());
925        // NDWI = (0.3 - 0.2) / (0.3 + 0.2) = 0.1 / 0.5 = 0.2
926        assert!((result.ok().flatten().expect("Should have value") - 0.2).abs() < 0.001);
927    }
928
929    #[test]
930    fn test_pixel_function_band_math() {
931        let func = PixelFunction::BandMath {
932            expression: "(B1 + B2) / 2".to_string(),
933        };
934
935        let values = vec![Some(10.0), Some(20.0)];
936        let result = func.apply(&values);
937        assert!(result.is_ok());
938        assert_eq!(result.ok().flatten(), Some(15.0));
939
940        // Test with sqrt
941        let func_sqrt = PixelFunction::BandMath {
942            expression: "sqrt(B1)".to_string(),
943        };
944        let values_sqrt = vec![Some(16.0)];
945        let result = func_sqrt.apply(&values_sqrt);
946        assert!(result.is_ok());
947        assert_eq!(result.ok().flatten(), Some(4.0));
948
949        // Test with abs
950        let func_abs = PixelFunction::BandMath {
951            expression: "abs(B1)".to_string(),
952        };
953        let values_abs = vec![Some(-5.0)];
954        let result = func_abs.apply(&values_abs);
955        assert!(result.is_ok());
956        assert_eq!(result.ok().flatten(), Some(5.0));
957    }
958
959    #[test]
960    fn test_pixel_function_lookup_table_nearest() {
961        let func = PixelFunction::LookupTable {
962            table: vec![(0.0, 10.0), (0.5, 20.0), (1.0, 30.0)],
963            interpolation: "nearest".to_string(),
964        };
965
966        // Exact match
967        let values = vec![Some(0.5)];
968        let result = func.apply(&values);
969        assert!(result.is_ok());
970        assert_eq!(result.ok().flatten(), Some(20.0));
971
972        // Nearest neighbor
973        let values = vec![Some(0.7)];
974        let result = func.apply(&values);
975        assert!(result.is_ok());
976        assert_eq!(result.ok().flatten(), Some(20.0)); // Closest to 0.5
977    }
978
979    #[test]
980    fn test_pixel_function_lookup_table_linear() {
981        let func = PixelFunction::LookupTable {
982            table: vec![(0.0, 10.0), (1.0, 30.0)],
983            interpolation: "linear".to_string(),
984        };
985
986        // Interpolated value
987        let values = vec![Some(0.5)];
988        let result = func.apply(&values);
989        assert!(result.is_ok());
990        assert_eq!(result.ok().flatten(), Some(20.0)); // Linear interpolation: 10 + 0.5 * (30-10)
991
992        // Edge case: below range
993        let values = vec![Some(-1.0)];
994        let result = func.apply(&values);
995        assert!(result.is_ok());
996        assert_eq!(result.ok().flatten(), Some(10.0));
997
998        // Edge case: above range
999        let values = vec![Some(2.0)];
1000        let result = func.apply(&values);
1001        assert!(result.is_ok());
1002        assert_eq!(result.ok().flatten(), Some(30.0));
1003    }
1004
1005    #[test]
1006    fn test_pixel_function_conditional() {
1007        let func = PixelFunction::Conditional {
1008            condition: "B1 > 0.5".to_string(),
1009            value_if_true: "B1 * 2".to_string(),
1010            value_if_false: "B1 / 2".to_string(),
1011        };
1012
1013        // True case
1014        let values_true = vec![Some(0.8)];
1015        let result = func.apply(&values_true);
1016        assert!(result.is_ok());
1017        assert_eq!(result.ok().flatten(), Some(1.6));
1018
1019        // False case
1020        let values_false = vec![Some(0.3)];
1021        let result = func.apply(&values_false);
1022        assert!(result.is_ok());
1023        assert_eq!(result.ok().flatten(), Some(0.15));
1024    }
1025
1026    #[test]
1027    fn test_pixel_function_multiply() {
1028        let func = PixelFunction::Multiply;
1029
1030        let values = vec![Some(2.0), Some(3.0), Some(4.0)];
1031        let result = func.apply(&values);
1032        assert!(result.is_ok());
1033        assert_eq!(result.ok().flatten(), Some(24.0));
1034
1035        // With NoData
1036        let values_nodata = vec![Some(2.0), None, Some(4.0)];
1037        let result = func.apply(&values_nodata);
1038        assert!(result.is_ok());
1039        assert_eq!(result.ok().flatten(), Some(8.0)); // Only multiplies valid values
1040    }
1041
1042    #[test]
1043    fn test_pixel_function_divide() {
1044        let func = PixelFunction::Divide;
1045
1046        let values = vec![Some(10.0), Some(2.0)];
1047        let result = func.apply(&values);
1048        assert!(result.is_ok());
1049        assert_eq!(result.ok().flatten(), Some(5.0));
1050
1051        // Division by zero
1052        let values_zero = vec![Some(10.0), Some(0.0)];
1053        let result = func.apply(&values_zero);
1054        assert!(result.is_ok());
1055        assert_eq!(result.ok().flatten(), None);
1056    }
1057
1058    #[test]
1059    fn test_pixel_function_square_root() {
1060        let func = PixelFunction::SquareRoot;
1061
1062        let values = vec![Some(25.0)];
1063        let result = func.apply(&values);
1064        assert!(result.is_ok());
1065        assert_eq!(result.ok().flatten(), Some(5.0));
1066
1067        // Negative value
1068        let values_neg = vec![Some(-4.0)];
1069        let result = func.apply(&values_neg);
1070        assert!(result.is_ok());
1071        assert_eq!(result.ok().flatten(), None);
1072    }
1073
1074    #[test]
1075    fn test_pixel_function_absolute() {
1076        let func = PixelFunction::Absolute;
1077
1078        let values_pos = vec![Some(5.0)];
1079        let result = func.apply(&values_pos);
1080        assert!(result.is_ok());
1081        assert_eq!(result.ok().flatten(), Some(5.0));
1082
1083        let values_neg = vec![Some(-5.0)];
1084        let result = func.apply(&values_neg);
1085        assert!(result.is_ok());
1086        assert_eq!(result.ok().flatten(), Some(5.0));
1087    }
1088
1089    #[test]
1090    fn test_pixel_function_validation() {
1091        // NDVI validation
1092        let ndvi_func = PixelFunction::Ndvi;
1093        let source = VrtSource::new(SourceFilename::absolute("/test.tif"), 1);
1094        let sources_valid = vec![source.clone(), source.clone()];
1095        assert!(ndvi_func.validate(&sources_valid).is_ok());
1096
1097        let sources_invalid = vec![source.clone()];
1098        assert!(ndvi_func.validate(&sources_invalid).is_err());
1099
1100        // BandMath validation
1101        let math_func = PixelFunction::BandMath {
1102            expression: "B1 + B2".to_string(),
1103        };
1104        assert!(math_func.validate(&sources_valid).is_ok());
1105
1106        let empty_expr = PixelFunction::BandMath {
1107            expression: "".to_string(),
1108        };
1109        assert!(empty_expr.validate(&sources_valid).is_err());
1110
1111        // LookupTable validation
1112        let lut_func = PixelFunction::LookupTable {
1113            table: vec![(0.0, 10.0)],
1114            interpolation: "linear".to_string(),
1115        };
1116        assert!(lut_func.validate(&sources_valid).is_ok());
1117
1118        let empty_lut = PixelFunction::LookupTable {
1119            table: vec![],
1120            interpolation: "linear".to_string(),
1121        };
1122        assert!(empty_lut.validate(&sources_valid).is_err());
1123    }
1124}