#ifndef CERES_EXAMPLES_FIELDS_OF_EXPERTS_H_
#define CERES_EXAMPLES_FIELDS_OF_EXPERTS_H_
#include <iostream>
#include <string>
#include <vector>

#include "ceres/cost_function.h"
#include "ceres/loss_function.h"
#include "ceres/sized_cost_function.h"
#include "pgm_image.h"
namespace ceres {
namespace examples {
// Cost function built from a single filter's coefficients. Evaluate() is
// defined out of line.
// NOTE(review): presumably it computes the filter response over the parameter
// blocks. Confirm against the implementation file.
//
// The filter coefficients are copied into this object. Holding a reference
// instead would dangle whenever the constructor is given a temporary, or when
// the vector that owns the coefficients is destroyed or reallocated while
// this cost function is still in use.
class FieldsOfExpertsCost : public ceres::CostFunction {
 public:
  // filter: filter coefficients. They are copied, so the caller's vector
  // does not need to outlive this object.
  explicit FieldsOfExpertsCost(const std::vector<double>& filter);

  virtual bool Evaluate(double const* const* parameters,
                        double* residuals,
                        double** jacobians) const;

 private:
  // Owned, immutable copy of the filter coefficients.
  const std::vector<double> filter_;
};
// Loss function parameterized by one scalar, alpha, which is fixed at
// construction. Evaluate() is defined out of line.
class FieldsOfExpertsLoss : public ceres::LossFunction {
 public:
  explicit FieldsOfExpertsLoss(double alpha)
      : alpha_(alpha) {}

  // Implements ceres::LossFunction::Evaluate.
  // NOTE(review): per the Ceres LossFunction contract, this presumably writes
  // rho(s), rho'(s) and rho''(s) for s = sq_norm into rho[0..2]. Confirm
  // against the implementation.
  virtual void Evaluate(double sq_norm, double* rho) const;

 private:
  // Scalar weight supplied at construction; never modified afterwards.
  const double alpha_;
};
// Fields of Experts model read from a file by LoadFromFile(). It holds
// NumFilters() coefficient vectors (filters_), a per-filter alpha_ value,
// and x/y index offset tables.
// NOTE(review): NumVariables() is Size() * Size(), so the filters are
// presumably Size() x Size() patches. Confirm against the loader.
class FieldsOfExperts {
 public:
  // Defined out of line.
  // NOTE(review): confirm that it initializes size_ and num_filters_, since
  // they are read by the inline accessors below.
  FieldsOfExperts();

  // Loads the model from `filename`. Defined out of line; the bool result
  // presumably reports success. Confirm against the implementation.
  bool LoadFromFile(const std::string& filename);

  // Size parameter read from the model. NumVariables() squares it.
  int Size() const { return size_; }

  // Returns Size() * Size().
  int NumVariables() const { return size_ * size_; }

  // Number of filters in the model.
  int NumFilters() const { return num_filters_; }

  // Factories for the cost and loss associated with `alpha_index`.
  // NOTE(review): both return raw pointers. Ownership presumably passes to the
  // caller (e.g. a ceres::Problem); confirm. If the returned cost function
  // refers to filter data owned here, this object must outlive it.
  ceres::CostFunction* NewCostFunction(int alpha_index) const;
  ceres::LossFunction* NewLossFunction(int alpha_index) const;

  // Read-only views of the x / y offset tables.
  const std::vector<int>& GetXDeltaIndices() const {
    return x_delta_indices_;
  }
  const std::vector<int>& GetYDeltaIndices() const {
    return y_delta_indices_;
  }

 private:
  int size_;
  int num_filters_;
  std::vector<int> x_delta_indices_;
  std::vector<int> y_delta_indices_;
  std::vector<double> alpha_;
  std::vector<std::vector<double> > filters_;
};
} }
#endif