Skip to main content

oxigdal_vrt/
band.rs

1//! VRT virtual band configuration
2
3use crate::error::{Result, VrtError};
4use crate::source::VrtSource;
5use oxigdal_core::types::{ColorInterpretation, NoDataValue, RasterDataType};
6use serde::{Deserialize, Serialize};
7
8/// VRT band configuration
9#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
10pub struct VrtBand {
11    /// Band number (1-based)
12    pub band: usize,
13    /// Data type
14    pub data_type: RasterDataType,
15    /// Color interpretation
16    pub color_interp: ColorInterpretation,
17    /// NoData value
18    pub nodata: NoDataValue,
19    /// Sources for this band
20    pub sources: Vec<VrtSource>,
21    /// Block size (tile dimensions)
22    pub block_size: Option<(u32, u32)>,
23    /// Pixel function (for on-the-fly computation)
24    pub pixel_function: Option<PixelFunction>,
25    /// Offset for scaling
26    pub offset: Option<f64>,
27    /// Scale factor
28    pub scale: Option<f64>,
29    /// Color table
30    pub color_table: Option<ColorTable>,
31}
32
33impl VrtBand {
34    /// Creates a new VRT band
35    pub fn new(band: usize, data_type: RasterDataType) -> Self {
36        Self {
37            band,
38            data_type,
39            color_interp: ColorInterpretation::Undefined,
40            nodata: NoDataValue::None,
41            sources: Vec::new(),
42            block_size: None,
43            pixel_function: None,
44            offset: None,
45            scale: None,
46            color_table: None,
47        }
48    }
49
50    /// Creates a simple band with a single source
51    pub fn simple(band: usize, data_type: RasterDataType, source: VrtSource) -> Self {
52        Self {
53            band,
54            data_type,
55            color_interp: ColorInterpretation::Undefined,
56            nodata: source.nodata.unwrap_or(NoDataValue::None),
57            sources: vec![source],
58            block_size: None,
59            pixel_function: None,
60            offset: None,
61            scale: None,
62            color_table: None,
63        }
64    }
65
66    /// Adds a source to this band
67    pub fn add_source(&mut self, source: VrtSource) {
68        self.sources.push(source);
69    }
70
71    /// Sets the color interpretation
72    pub fn with_color_interp(mut self, color_interp: ColorInterpretation) -> Self {
73        self.color_interp = color_interp;
74        self
75    }
76
77    /// Sets the NoData value
78    pub fn with_nodata(mut self, nodata: NoDataValue) -> Self {
79        self.nodata = nodata;
80        self
81    }
82
83    /// Sets the block size
84    pub fn with_block_size(mut self, width: u32, height: u32) -> Self {
85        self.block_size = Some((width, height));
86        self
87    }
88
89    /// Sets the pixel function
90    pub fn with_pixel_function(mut self, function: PixelFunction) -> Self {
91        self.pixel_function = Some(function);
92        self
93    }
94
95    /// Sets the offset and scale
96    pub fn with_scaling(mut self, offset: f64, scale: f64) -> Self {
97        self.offset = Some(offset);
98        self.scale = Some(scale);
99        self
100    }
101
102    /// Sets the color table
103    pub fn with_color_table(mut self, color_table: ColorTable) -> Self {
104        self.color_table = Some(color_table);
105        self
106    }
107
108    /// Validates the band configuration
109    ///
110    /// # Errors
111    /// Returns an error if the band is invalid
112    pub fn validate(&self) -> Result<()> {
113        if self.band == 0 {
114            return Err(VrtError::invalid_band("Band number must be >= 1"));
115        }
116
117        if self.sources.is_empty() && self.pixel_function.is_none() {
118            return Err(VrtError::invalid_band(
119                "Band must have at least one source or a pixel function",
120            ));
121        }
122
123        // Validate all sources
124        for source in &self.sources {
125            source.validate()?;
126        }
127
128        // Validate pixel function if present
129        if let Some(ref func) = self.pixel_function {
130            func.validate(&self.sources)?;
131        }
132
133        Ok(())
134    }
135
136    /// Checks if this band has multiple sources
137    pub fn has_multiple_sources(&self) -> bool {
138        self.sources.len() > 1
139    }
140
141    /// Checks if this band uses a pixel function
142    pub fn uses_pixel_function(&self) -> bool {
143        self.pixel_function.is_some()
144    }
145
146    /// Applies scaling to a value
147    pub fn apply_scaling(&self, value: f64) -> f64 {
148        let scaled = if let Some(scale) = self.scale {
149            value * scale
150        } else {
151            value
152        };
153
154        if let Some(offset) = self.offset {
155            scaled + offset
156        } else {
157            scaled
158        }
159    }
160}
161
162/// Pixel function for on-the-fly computation
163#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
164pub enum PixelFunction {
165    /// Average of all sources
166    Average,
167    /// Minimum of all sources
168    Min,
169    /// Maximum of all sources
170    Max,
171    /// Sum of all sources
172    Sum,
173    /// First valid (non-NoData) value
174    FirstValid,
175    /// Last valid (non-NoData) value
176    LastValid,
177    /// Weighted average (requires weights)
178    WeightedAverage {
179        /// Weights for each source
180        weights: Vec<f64>,
181    },
182    /// NDVI: (NIR - Red) / (NIR + Red)
183    /// Requires exactly 2 sources: [Red, NIR]
184    Ndvi,
185    /// Enhanced Vegetation Index: 2.5 * (NIR - Red) / (NIR + 6*Red - 7.5*Blue + 1)
186    /// Requires exactly 3 sources: [Red, NIR, Blue]
187    Evi,
188    /// Normalized Difference Water Index: (Green - NIR) / (Green + NIR)
189    /// Requires exactly 2 sources: [Green, NIR]
190    Ndwi,
191    /// Band math expression
192    /// Supports operations: +, -, *, /, sqrt, pow, abs, min, max
193    /// Variables are named as B1, B2, B3, etc. corresponding to source bands
194    BandMath {
195        /// Expression string (e.g., "(B1 + B2) / 2", "sqrt(B1 * B2)")
196        expression: String,
197    },
198    /// Lookup table transformation
199    /// Maps input values to output values
200    LookupTable {
201        /// Lookup table: vec of (input_value, output_value) pairs
202        table: Vec<(f64, f64)>,
203        /// Interpolation method: "nearest", "linear"
204        interpolation: String,
205    },
206    /// Conditional logic: if condition then value_if_true else value_if_false
207    /// Condition format: "B1 > 0.5", "B1 >= B2", etc.
208    Conditional {
209        /// Condition expression
210        condition: String,
211        /// Value if condition is true (can be a constant or expression)
212        value_if_true: String,
213        /// Value if condition is false (can be a constant or expression)
214        value_if_false: String,
215    },
216    /// Multiply source values
217    Multiply,
218    /// Divide first source by second (handles division by zero)
219    Divide,
220    /// Square root of source value
221    SquareRoot,
222    /// Absolute value of source value
223    Absolute,
224    /// Custom function (not yet implemented)
225    Custom {
226        /// Function name
227        name: String,
228    },
229}
230
231impl PixelFunction {
232    /// Validates the pixel function against sources
233    ///
234    /// # Errors
235    /// Returns an error if the function is invalid for the given sources
236    pub fn validate(&self, sources: &[VrtSource]) -> Result<()> {
237        match self {
238            Self::WeightedAverage { weights } => {
239                if weights.len() != sources.len() {
240                    return Err(VrtError::invalid_band(format!(
241                        "WeightedAverage requires {} weights, got {}",
242                        sources.len(),
243                        weights.len()
244                    )));
245                }
246
247                // Check that weights sum to approximately 1.0
248                let sum: f64 = weights.iter().sum();
249                if (sum - 1.0).abs() > 0.001 {
250                    return Err(VrtError::invalid_band(format!(
251                        "Weights should sum to 1.0, got {}",
252                        sum
253                    )));
254                }
255            }
256            Self::Ndvi | Self::Ndwi => {
257                if sources.len() != 2 {
258                    return Err(VrtError::invalid_band(format!(
259                        "{:?} requires exactly 2 sources, got {}",
260                        self,
261                        sources.len()
262                    )));
263                }
264            }
265            Self::Evi => {
266                if sources.len() != 3 {
267                    return Err(VrtError::invalid_band(format!(
268                        "EVI requires exactly 3 sources, got {}",
269                        sources.len()
270                    )));
271                }
272            }
273            Self::BandMath { expression } => {
274                if expression.trim().is_empty() {
275                    return Err(VrtError::invalid_band(
276                        "BandMath expression cannot be empty",
277                    ));
278                }
279            }
280            Self::LookupTable { table, .. } => {
281                if table.is_empty() {
282                    return Err(VrtError::invalid_band("LookupTable cannot be empty"));
283                }
284            }
285            Self::Conditional {
286                condition,
287                value_if_true,
288                value_if_false,
289            } => {
290                if condition.trim().is_empty() {
291                    return Err(VrtError::invalid_band(
292                        "Conditional condition cannot be empty",
293                    ));
294                }
295                if value_if_true.trim().is_empty() || value_if_false.trim().is_empty() {
296                    return Err(VrtError::invalid_band("Conditional values cannot be empty"));
297                }
298            }
299            Self::Divide | Self::Multiply => {
300                if sources.len() < 2 {
301                    return Err(VrtError::invalid_band(format!(
302                        "{:?} requires at least 2 sources, got {}",
303                        self,
304                        sources.len()
305                    )));
306                }
307            }
308            Self::SquareRoot | Self::Absolute => {
309                if sources.is_empty() {
310                    return Err(VrtError::invalid_band(format!(
311                        "{:?} requires at least 1 source",
312                        self
313                    )));
314                }
315            }
316            Self::Custom { name } => {
317                return Err(VrtError::InvalidPixelFunction {
318                    function: name.clone(),
319                });
320            }
321            _ => {}
322        }
323        Ok(())
324    }
325
326    /// Applies the pixel function to a set of values
327    ///
328    /// # Errors
329    /// Returns an error if the function cannot be applied
330    pub fn apply(&self, values: &[Option<f64>]) -> Result<Option<f64>> {
331        match self {
332            Self::Average => {
333                let valid: Vec<f64> = values.iter().filter_map(|v| *v).collect();
334                if valid.is_empty() {
335                    Ok(None)
336                } else {
337                    Ok(Some(valid.iter().sum::<f64>() / valid.len() as f64))
338                }
339            }
340            Self::Min => Ok(values
341                .iter()
342                .filter_map(|v| *v)
343                .min_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))),
344            Self::Max => Ok(values
345                .iter()
346                .filter_map(|v| *v)
347                .max_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))),
348            Self::Sum => {
349                let valid: Vec<f64> = values.iter().filter_map(|v| *v).collect();
350                if valid.is_empty() {
351                    Ok(None)
352                } else {
353                    Ok(Some(valid.iter().sum()))
354                }
355            }
356            Self::FirstValid => Ok(values.iter().find_map(|v| *v)),
357            Self::LastValid => Ok(values.iter().rev().find_map(|v| *v)),
358            Self::WeightedAverage { weights } => {
359                if weights.len() != values.len() {
360                    return Err(VrtError::invalid_band("Weight count mismatch"));
361                }
362
363                let mut sum = 0.0;
364                let mut weight_sum = 0.0;
365
366                for (value, weight) in values.iter().zip(weights.iter()) {
367                    if let Some(v) = value {
368                        sum += v * weight;
369                        weight_sum += weight;
370                    }
371                }
372
373                if weight_sum > 0.0 {
374                    Ok(Some(sum / weight_sum))
375                } else {
376                    Ok(None)
377                }
378            }
379            Self::Ndvi => {
380                // NDVI = (NIR - Red) / (NIR + Red)
381                if values.len() != 2 {
382                    return Err(VrtError::invalid_band("NDVI requires exactly 2 values"));
383                }
384                match (values[0], values[1]) {
385                    (Some(red), Some(nir)) => {
386                        let denominator = nir + red;
387                        if denominator.abs() < f64::EPSILON {
388                            Ok(None) // Avoid division by zero
389                        } else {
390                            Ok(Some((nir - red) / denominator))
391                        }
392                    }
393                    _ => Ok(None),
394                }
395            }
396            Self::Evi => {
397                // EVI = 2.5 * (NIR - Red) / (NIR + 6*Red - 7.5*Blue + 1)
398                if values.len() != 3 {
399                    return Err(VrtError::invalid_band("EVI requires exactly 3 values"));
400                }
401                match (values[0], values[1], values[2]) {
402                    (Some(red), Some(nir), Some(blue)) => {
403                        let denominator = nir + 6.0 * red - 7.5 * blue + 1.0;
404                        if denominator.abs() < f64::EPSILON {
405                            Ok(None)
406                        } else {
407                            Ok(Some(2.5 * (nir - red) / denominator))
408                        }
409                    }
410                    _ => Ok(None),
411                }
412            }
413            Self::Ndwi => {
414                // NDWI = (Green - NIR) / (Green + NIR)
415                if values.len() != 2 {
416                    return Err(VrtError::invalid_band("NDWI requires exactly 2 values"));
417                }
418                match (values[0], values[1]) {
419                    (Some(green), Some(nir)) => {
420                        let denominator = green + nir;
421                        if denominator.abs() < f64::EPSILON {
422                            Ok(None)
423                        } else {
424                            Ok(Some((green - nir) / denominator))
425                        }
426                    }
427                    _ => Ok(None),
428                }
429            }
430            Self::BandMath { expression } => Self::evaluate_expression(expression, values),
431            Self::LookupTable {
432                table,
433                interpolation,
434            } => {
435                if values.is_empty() {
436                    return Ok(None);
437                }
438                if let Some(value) = values[0] {
439                    Self::apply_lookup_table(value, table, interpolation)
440                } else {
441                    Ok(None)
442                }
443            }
444            Self::Conditional {
445                condition,
446                value_if_true,
447                value_if_false,
448            } => Self::evaluate_conditional(condition, value_if_true, value_if_false, values),
449            Self::Multiply => {
450                let valid: Vec<f64> = values.iter().filter_map(|v| *v).collect();
451                if valid.is_empty() {
452                    Ok(None)
453                } else {
454                    Ok(Some(valid.iter().product()))
455                }
456            }
457            Self::Divide => {
458                if values.len() < 2 {
459                    return Err(VrtError::invalid_band("Divide requires at least 2 values"));
460                }
461                match (values[0], values[1]) {
462                    (Some(numerator), Some(denominator)) => {
463                        if denominator.abs() < f64::EPSILON {
464                            Ok(None) // Avoid division by zero
465                        } else {
466                            Ok(Some(numerator / denominator))
467                        }
468                    }
469                    _ => Ok(None),
470                }
471            }
472            Self::SquareRoot => {
473                if values.is_empty() {
474                    return Ok(None);
475                }
476                values[0].map_or(Ok(None), |v| {
477                    if v < 0.0 {
478                        Ok(None) // Negative values have no real square root
479                    } else {
480                        Ok(Some(v.sqrt()))
481                    }
482                })
483            }
484            Self::Absolute => {
485                if values.is_empty() {
486                    return Ok(None);
487                }
488                Ok(values[0].map(|v| v.abs()))
489            }
490            Self::Custom { name } => Err(VrtError::InvalidPixelFunction {
491                function: name.clone(),
492            }),
493        }
494    }
495
496    /// Evaluates a band math expression
497    fn evaluate_expression(expression: &str, values: &[Option<f64>]) -> Result<Option<f64>> {
498        // Simple expression evaluator for basic band math
499        // Supports: +, -, *, /, sqrt, pow, abs, min, max
500        // Variables: B1, B2, B3, etc.
501
502        let mut expr = expression.to_string();
503
504        // Replace band variables with actual values
505        for (i, value) in values.iter().enumerate() {
506            let var = format!("B{}", i + 1);
507            if let Some(v) = value {
508                expr = expr.replace(&var, &v.to_string());
509            } else {
510                return Ok(None); // If any band is NoData, result is NoData
511            }
512        }
513
514        // Basic expression evaluation (simplified)
515        // For production, consider using a proper expression parser like `evalexpr`
516        match Self::simple_eval(&expr) {
517            Ok(result) => Ok(Some(result)),
518            Err(_) => Err(VrtError::invalid_band(format!(
519                "Failed to evaluate expression: {}",
520                expression
521            ))),
522        }
523    }
524
525    /// Simple expression evaluator (basic implementation)
526    fn simple_eval(expr: &str) -> Result<f64> {
527        let expr = expr.trim();
528
529        // Try to parse as number first
530        if let Ok(num) = expr.parse::<f64>() {
531            return Ok(num);
532        }
533
534        // Handle sqrt
535        if expr.starts_with("sqrt(") && expr.ends_with(')') {
536            let inner = &expr[5..expr.len() - 1];
537            let val = Self::simple_eval(inner)?;
538            if val < 0.0 {
539                return Err(VrtError::invalid_band("Square root of negative number"));
540            }
541            return Ok(val.sqrt());
542        }
543
544        // Handle abs
545        if expr.starts_with("abs(") && expr.ends_with(')') {
546            let inner = &expr[4..expr.len() - 1];
547            let val = Self::simple_eval(inner)?;
548            return Ok(val.abs());
549        }
550
551        // Handle parentheses
552        if expr.starts_with('(') && expr.ends_with(')') {
553            // Check if these are balanced outer parens
554            let inner = &expr[1..expr.len() - 1];
555            let mut depth = 0;
556            let mut is_outer = true;
557            for ch in inner.chars() {
558                if ch == '(' {
559                    depth += 1;
560                } else if ch == ')' {
561                    depth -= 1;
562                    if depth < 0 {
563                        is_outer = false;
564                        break;
565                    }
566                }
567            }
568            if is_outer && depth == 0 {
569                return Self::simple_eval(inner);
570            }
571        }
572
573        // Handle binary operations (search from right to left for left-to-right evaluation)
574        // Process + and - first (lower precedence), then * and / (higher precedence)
575        for op in &['+', '-'] {
576            let mut depth = 0;
577            for (i, ch) in expr.char_indices().rev() {
578                if ch == ')' {
579                    depth += 1;
580                } else if ch == '(' {
581                    depth -= 1;
582                } else if depth == 0 && ch == *op && i > 0 && i < expr.len() - 1 {
583                    let left = Self::simple_eval(&expr[..i])?;
584                    let right = Self::simple_eval(&expr[i + 1..])?;
585                    return match op {
586                        '+' => Ok(left + right),
587                        '-' => Ok(left - right),
588                        _ => unreachable!(),
589                    };
590                }
591            }
592        }
593
594        for op in &['*', '/'] {
595            let mut depth = 0;
596            for (i, ch) in expr.char_indices().rev() {
597                if ch == ')' {
598                    depth += 1;
599                } else if ch == '(' {
600                    depth -= 1;
601                } else if depth == 0 && ch == *op && i > 0 && i < expr.len() - 1 {
602                    let left = Self::simple_eval(&expr[..i])?;
603                    let right = Self::simple_eval(&expr[i + 1..])?;
604                    return match op {
605                        '*' => Ok(left * right),
606                        '/' => {
607                            if right.abs() < f64::EPSILON {
608                                Err(VrtError::invalid_band("Division by zero"))
609                            } else {
610                                Ok(left / right)
611                            }
612                        }
613                        _ => unreachable!(),
614                    };
615                }
616            }
617        }
618
619        Err(VrtError::invalid_band(format!(
620            "Cannot parse expression: {}",
621            expr
622        )))
623    }
624
625    /// Applies lookup table transformation
626    fn apply_lookup_table(
627        value: f64,
628        table: &[(f64, f64)],
629        interpolation: &str,
630    ) -> Result<Option<f64>> {
631        if table.is_empty() {
632            return Ok(None);
633        }
634
635        match interpolation {
636            "nearest" => {
637                // Find nearest entry
638                let mut best_idx = 0;
639                let mut best_dist = (table[0].0 - value).abs();
640
641                for (i, (input, _)) in table.iter().enumerate() {
642                    let dist = (input - value).abs();
643                    if dist < best_dist {
644                        best_dist = dist;
645                        best_idx = i;
646                    }
647                }
648
649                Ok(Some(table[best_idx].1))
650            }
651            "linear" => {
652                // Linear interpolation
653                if value <= table[0].0 {
654                    return Ok(Some(table[0].1));
655                }
656                if value >= table[table.len() - 1].0 {
657                    return Ok(Some(table[table.len() - 1].1));
658                }
659
660                // Find surrounding points
661                for i in 0..table.len() - 1 {
662                    if value >= table[i].0 && value <= table[i + 1].0 {
663                        let x0 = table[i].0;
664                        let y0 = table[i].1;
665                        let x1 = table[i + 1].0;
666                        let y1 = table[i + 1].1;
667
668                        let t = (value - x0) / (x1 - x0);
669                        return Ok(Some(y0 + t * (y1 - y0)));
670                    }
671                }
672
673                Ok(Some(table[0].1))
674            }
675            _ => Err(VrtError::invalid_band(format!(
676                "Unknown interpolation method: {}",
677                interpolation
678            ))),
679        }
680    }
681
682    /// Evaluates conditional expression
683    fn evaluate_conditional(
684        condition: &str,
685        value_if_true: &str,
686        value_if_false: &str,
687        values: &[Option<f64>],
688    ) -> Result<Option<f64>> {
689        // Simple condition evaluator
690        // Supports: >, <, >=, <=, ==, !=
691
692        let cond_result = Self::evaluate_condition(condition, values)?;
693
694        let target_expr = if cond_result {
695            value_if_true
696        } else {
697            value_if_false
698        };
699
700        // Evaluate the target expression
701        Self::evaluate_expression(target_expr, values)
702    }
703
704    /// Evaluates a boolean condition
705    fn evaluate_condition(condition: &str, values: &[Option<f64>]) -> Result<bool> {
706        let condition = condition.trim();
707
708        // Try different comparison operators in order (longer ones first to avoid false matches)
709        let operators = [">=", "<=", "==", "!=", ">", "<"];
710
711        for op_str in &operators {
712            if let Some(pos) = condition.find(op_str) {
713                let left_expr = condition[..pos].trim();
714                let right_expr = condition[pos + op_str.len()..].trim();
715
716                let left_val = Self::evaluate_expression(left_expr, values)?
717                    .ok_or_else(|| VrtError::invalid_band("Left side of condition is NoData"))?;
718
719                let right_val = Self::evaluate_expression(right_expr, values)?
720                    .ok_or_else(|| VrtError::invalid_band("Right side of condition is NoData"))?;
721
722                let result = match *op_str {
723                    ">=" => left_val >= right_val,
724                    "<=" => left_val <= right_val,
725                    ">" => left_val > right_val,
726                    "<" => left_val < right_val,
727                    "==" => (left_val - right_val).abs() < f64::EPSILON,
728                    "!=" => (left_val - right_val).abs() >= f64::EPSILON,
729                    _ => {
730                        return Err(VrtError::invalid_band(format!(
731                            "Unknown operator: {}",
732                            op_str
733                        )));
734                    }
735                };
736
737                return Ok(result);
738            }
739        }
740
741        Err(VrtError::invalid_band(format!(
742            "Cannot parse condition: {}",
743            condition
744        )))
745    }
746}
747
748/// Color table entry
749#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
750pub struct ColorEntry {
751    /// Color value (index)
752    pub value: u16,
753    /// Red component (0-255)
754    pub r: u8,
755    /// Green component (0-255)
756    pub g: u8,
757    /// Blue component (0-255)
758    pub b: u8,
759    /// Alpha component (0-255)
760    pub a: u8,
761}
762
763impl ColorEntry {
764    /// Creates a new color entry
765    pub const fn new(value: u16, r: u8, g: u8, b: u8, a: u8) -> Self {
766        Self { value, r, g, b, a }
767    }
768
769    /// Creates an opaque color entry
770    pub const fn rgb(value: u16, r: u8, g: u8, b: u8) -> Self {
771        Self::new(value, r, g, b, 255)
772    }
773}
774
775/// Color table (palette)
776#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
777pub struct ColorTable {
778    /// Color entries
779    pub entries: Vec<ColorEntry>,
780}
781
782impl ColorTable {
783    /// Creates a new empty color table
784    pub fn new() -> Self {
785        Self {
786            entries: Vec::new(),
787        }
788    }
789
790    /// Creates a color table with entries
791    pub fn with_entries(entries: Vec<ColorEntry>) -> Self {
792        Self { entries }
793    }
794
795    /// Adds a color entry
796    pub fn add_entry(&mut self, entry: ColorEntry) {
797        self.entries.push(entry);
798    }
799
800    /// Gets a color entry by value
801    pub fn get(&self, value: u16) -> Option<&ColorEntry> {
802        self.entries.iter().find(|e| e.value == value)
803    }
804}
805
806impl Default for ColorTable {
807    fn default() -> Self {
808        Self::new()
809    }
810}
811
812#[cfg(test)]
813mod tests {
814    use super::*;
815    use crate::source::SourceFilename;
816
817    #[test]
818    fn test_vrt_band_creation() {
819        let band = VrtBand::new(1, RasterDataType::UInt8);
820        assert_eq!(band.band, 1);
821        assert_eq!(band.data_type, RasterDataType::UInt8);
822    }
823
824    #[test]
825    fn test_vrt_band_validation() {
826        let source = VrtSource::new(SourceFilename::absolute("/test.tif"), 1);
827        let band = VrtBand::simple(1, RasterDataType::UInt8, source);
828        assert!(band.validate().is_ok());
829
830        let invalid_band = VrtBand::new(0, RasterDataType::UInt8);
831        assert!(invalid_band.validate().is_err());
832
833        let no_source_band = VrtBand::new(1, RasterDataType::UInt8);
834        assert!(no_source_band.validate().is_err());
835    }
836
837    #[test]
838    fn test_pixel_function_average() {
839        let func = PixelFunction::Average;
840        let values = vec![Some(1.0), Some(2.0), Some(3.0)];
841        let result = func.apply(&values);
842        assert!(result.is_ok());
843        assert_eq!(result.ok().flatten(), Some(2.0));
844
845        let with_none = vec![Some(1.0), None, Some(3.0)];
846        let result = func.apply(&with_none);
847        assert!(result.is_ok());
848        assert_eq!(result.ok().flatten(), Some(2.0));
849    }
850
851    #[test]
852    fn test_pixel_function_first_valid() {
853        let func = PixelFunction::FirstValid;
854        let values = vec![None, Some(2.0), Some(3.0)];
855        let result = func.apply(&values);
856        assert!(result.is_ok());
857        assert_eq!(result.ok().flatten(), Some(2.0));
858    }
859
860    #[test]
861    fn test_pixel_function_weighted_average() {
862        let func = PixelFunction::WeightedAverage {
863            weights: vec![0.5, 0.3, 0.2],
864        };
865        let values = vec![Some(10.0), Some(20.0), Some(30.0)];
866        let result = func.apply(&values);
867        assert!(result.is_ok());
868        // 10*0.5 + 20*0.3 + 30*0.2 = 5 + 6 + 6 = 17
869        assert_eq!(result.ok().flatten(), Some(17.0));
870    }
871
872    #[test]
873    fn test_band_scaling() {
874        let band = VrtBand::new(1, RasterDataType::UInt8).with_scaling(10.0, 2.0);
875
876        assert_eq!(band.apply_scaling(5.0), 20.0); // 5 * 2 + 10 = 20
877    }
878
879    #[test]
880    fn test_color_table() {
881        let mut table = ColorTable::new();
882        table.add_entry(ColorEntry::rgb(0, 255, 0, 0));
883        table.add_entry(ColorEntry::rgb(1, 0, 255, 0));
884        table.add_entry(ColorEntry::rgb(2, 0, 0, 255));
885
886        assert_eq!(table.entries.len(), 3);
887        assert_eq!(table.get(1).map(|e| e.g), Some(255));
888    }
889
890    #[test]
891    fn test_pixel_function_ndvi() {
892        let func = PixelFunction::Ndvi;
893
894        // Standard NDVI calculation
895        let values = vec![Some(0.1), Some(0.5)]; // Red, NIR
896        let result = func.apply(&values);
897        assert!(result.is_ok());
898        // NDVI = (0.5 - 0.1) / (0.5 + 0.1) = 0.4 / 0.6 = 0.666...
899        assert!((result.ok().flatten().expect("Should have value") - 0.666666).abs() < 0.001);
900
901        // With NoData
902        let values_nodata = vec![Some(0.1), None];
903        let result = func.apply(&values_nodata);
904        assert!(result.is_ok());
905        assert_eq!(result.ok().flatten(), None);
906
907        // Zero sum (edge case)
908        let values_zero = vec![Some(0.5), Some(-0.5)];
909        let result = func.apply(&values_zero);
910        assert!(result.is_ok());
911        assert_eq!(result.ok().flatten(), None); // Should return None to avoid division by zero
912    }
913
914    #[test]
915    fn test_pixel_function_evi() {
916        let func = PixelFunction::Evi;
917
918        // Standard EVI calculation
919        let values = vec![Some(0.1), Some(0.5), Some(0.05)]; // Red, NIR, Blue
920        let result = func.apply(&values);
921        assert!(result.is_ok());
922        // EVI = 2.5 * (0.5 - 0.1) / (0.5 + 6*0.1 - 7.5*0.05 + 1)
923        // = 2.5 * 0.4 / (0.5 + 0.6 - 0.375 + 1)
924        // = 1.0 / 1.725 = 0.5797...
925        let expected = 1.0 / 1.725;
926        assert!((result.ok().flatten().expect("Should have value") - expected).abs() < 0.001);
927    }
928
929    #[test]
930    fn test_pixel_function_ndwi() {
931        let func = PixelFunction::Ndwi;
932
933        // Standard NDWI calculation
934        let values = vec![Some(0.3), Some(0.2)]; // Green, NIR
935        let result = func.apply(&values);
936        assert!(result.is_ok());
937        // NDWI = (0.3 - 0.2) / (0.3 + 0.2) = 0.1 / 0.5 = 0.2
938        assert!((result.ok().flatten().expect("Should have value") - 0.2).abs() < 0.001);
939    }
940
941    #[test]
942    fn test_pixel_function_band_math() {
943        let func = PixelFunction::BandMath {
944            expression: "(B1 + B2) / 2".to_string(),
945        };
946
947        let values = vec![Some(10.0), Some(20.0)];
948        let result = func.apply(&values);
949        assert!(result.is_ok());
950        assert_eq!(result.ok().flatten(), Some(15.0));
951
952        // Test with sqrt
953        let func_sqrt = PixelFunction::BandMath {
954            expression: "sqrt(B1)".to_string(),
955        };
956        let values_sqrt = vec![Some(16.0)];
957        let result = func_sqrt.apply(&values_sqrt);
958        assert!(result.is_ok());
959        assert_eq!(result.ok().flatten(), Some(4.0));
960
961        // Test with abs
962        let func_abs = PixelFunction::BandMath {
963            expression: "abs(B1)".to_string(),
964        };
965        let values_abs = vec![Some(-5.0)];
966        let result = func_abs.apply(&values_abs);
967        assert!(result.is_ok());
968        assert_eq!(result.ok().flatten(), Some(5.0));
969    }
970
971    #[test]
972    fn test_pixel_function_lookup_table_nearest() {
973        let func = PixelFunction::LookupTable {
974            table: vec![(0.0, 10.0), (0.5, 20.0), (1.0, 30.0)],
975            interpolation: "nearest".to_string(),
976        };
977
978        // Exact match
979        let values = vec![Some(0.5)];
980        let result = func.apply(&values);
981        assert!(result.is_ok());
982        assert_eq!(result.ok().flatten(), Some(20.0));
983
984        // Nearest neighbor
985        let values = vec![Some(0.7)];
986        let result = func.apply(&values);
987        assert!(result.is_ok());
988        assert_eq!(result.ok().flatten(), Some(20.0)); // Closest to 0.5
989    }
990
991    #[test]
992    fn test_pixel_function_lookup_table_linear() {
993        let func = PixelFunction::LookupTable {
994            table: vec![(0.0, 10.0), (1.0, 30.0)],
995            interpolation: "linear".to_string(),
996        };
997
998        // Interpolated value
999        let values = vec![Some(0.5)];
1000        let result = func.apply(&values);
1001        assert!(result.is_ok());
1002        assert_eq!(result.ok().flatten(), Some(20.0)); // Linear interpolation: 10 + 0.5 * (30-10)
1003
1004        // Edge case: below range
1005        let values = vec![Some(-1.0)];
1006        let result = func.apply(&values);
1007        assert!(result.is_ok());
1008        assert_eq!(result.ok().flatten(), Some(10.0));
1009
1010        // Edge case: above range
1011        let values = vec![Some(2.0)];
1012        let result = func.apply(&values);
1013        assert!(result.is_ok());
1014        assert_eq!(result.ok().flatten(), Some(30.0));
1015    }
1016
1017    #[test]
1018    fn test_pixel_function_conditional() {
1019        let func = PixelFunction::Conditional {
1020            condition: "B1 > 0.5".to_string(),
1021            value_if_true: "B1 * 2".to_string(),
1022            value_if_false: "B1 / 2".to_string(),
1023        };
1024
1025        // True case
1026        let values_true = vec![Some(0.8)];
1027        let result = func.apply(&values_true);
1028        assert!(result.is_ok());
1029        assert_eq!(result.ok().flatten(), Some(1.6));
1030
1031        // False case
1032        let values_false = vec![Some(0.3)];
1033        let result = func.apply(&values_false);
1034        assert!(result.is_ok());
1035        assert_eq!(result.ok().flatten(), Some(0.15));
1036    }
1037
1038    #[test]
1039    fn test_pixel_function_multiply() {
1040        let func = PixelFunction::Multiply;
1041
1042        let values = vec![Some(2.0), Some(3.0), Some(4.0)];
1043        let result = func.apply(&values);
1044        assert!(result.is_ok());
1045        assert_eq!(result.ok().flatten(), Some(24.0));
1046
1047        // With NoData
1048        let values_nodata = vec![Some(2.0), None, Some(4.0)];
1049        let result = func.apply(&values_nodata);
1050        assert!(result.is_ok());
1051        assert_eq!(result.ok().flatten(), Some(8.0)); // Only multiplies valid values
1052    }
1053
1054    #[test]
1055    fn test_pixel_function_divide() {
1056        let func = PixelFunction::Divide;
1057
1058        let values = vec![Some(10.0), Some(2.0)];
1059        let result = func.apply(&values);
1060        assert!(result.is_ok());
1061        assert_eq!(result.ok().flatten(), Some(5.0));
1062
1063        // Division by zero
1064        let values_zero = vec![Some(10.0), Some(0.0)];
1065        let result = func.apply(&values_zero);
1066        assert!(result.is_ok());
1067        assert_eq!(result.ok().flatten(), None);
1068    }
1069
1070    #[test]
1071    fn test_pixel_function_square_root() {
1072        let func = PixelFunction::SquareRoot;
1073
1074        let values = vec![Some(25.0)];
1075        let result = func.apply(&values);
1076        assert!(result.is_ok());
1077        assert_eq!(result.ok().flatten(), Some(5.0));
1078
1079        // Negative value
1080        let values_neg = vec![Some(-4.0)];
1081        let result = func.apply(&values_neg);
1082        assert!(result.is_ok());
1083        assert_eq!(result.ok().flatten(), None);
1084    }
1085
1086    #[test]
1087    fn test_pixel_function_absolute() {
1088        let func = PixelFunction::Absolute;
1089
1090        let values_pos = vec![Some(5.0)];
1091        let result = func.apply(&values_pos);
1092        assert!(result.is_ok());
1093        assert_eq!(result.ok().flatten(), Some(5.0));
1094
1095        let values_neg = vec![Some(-5.0)];
1096        let result = func.apply(&values_neg);
1097        assert!(result.is_ok());
1098        assert_eq!(result.ok().flatten(), Some(5.0));
1099    }
1100
1101    #[test]
1102    fn test_pixel_function_validation() {
1103        // NDVI validation
1104        let ndvi_func = PixelFunction::Ndvi;
1105        let source = VrtSource::new(SourceFilename::absolute("/test.tif"), 1);
1106        let sources_valid = vec![source.clone(), source.clone()];
1107        assert!(ndvi_func.validate(&sources_valid).is_ok());
1108
1109        let sources_invalid = vec![source.clone()];
1110        assert!(ndvi_func.validate(&sources_invalid).is_err());
1111
1112        // BandMath validation
1113        let math_func = PixelFunction::BandMath {
1114            expression: "B1 + B2".to_string(),
1115        };
1116        assert!(math_func.validate(&sources_valid).is_ok());
1117
1118        let empty_expr = PixelFunction::BandMath {
1119            expression: "".to_string(),
1120        };
1121        assert!(empty_expr.validate(&sources_valid).is_err());
1122
1123        // LookupTable validation
1124        let lut_func = PixelFunction::LookupTable {
1125            table: vec![(0.0, 10.0)],
1126            interpolation: "linear".to_string(),
1127        };
1128        assert!(lut_func.validate(&sources_valid).is_ok());
1129
1130        let empty_lut = PixelFunction::LookupTable {
1131            table: vec![],
1132            interpolation: "linear".to_string(),
1133        };
1134        assert!(empty_lut.validate(&sources_valid).is_err());
1135    }
1136}